From 83f8996c176e705784e982891805fd2163901a45 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Thu, 30 Jun 2022 18:33:32 +0200 Subject: [PATCH 01/46] add sorting, gaus initializers, add gaus helpers to tools --- ott/core/initializers.py | 140 +++++++++++++++++++++++ ott/tools/gaussian_mixture/gaussian.py | 36 ++++++ ott/tools/gaussian_mixture/scale_tril.py | 28 ++++- 3 files changed, 198 insertions(+), 6 deletions(-) create mode 100644 ott/core/initializers.py diff --git a/ott/core/initializers.py b/ott/core/initializers.py new file mode 100644 index 000000000..e168f0795 --- /dev/null +++ b/ott/core/initializers.py @@ -0,0 +1,140 @@ +# Copyright 2022 The OTT Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sinkhorn initializers.""" +import functools +import jax +from jax import numpy as jnp + +from .linear_problems import LinearProblem +from ..tools.gaussian_mixture.gaussian import Gaussian +from ..geometry.pointcloud import PointCloud + +@jax.tree_util.register_pytree_node_class +class SinkhornInitializer(): + + def apply(self, linear_problem: LinearProblem) -> jnp.ndarray: + """ + Input: + linear_problem: OT problem between discrete distributions of size n and m + + Return: + dual potential, array of size m + """ + pass + + + +class GaussianInitializer(SinkhornInitializer): + + def __init__(self, stop_gradient=True) -> None: + super().__init__() + + self.stop_gradient = stop_gradient + + + def apply(self, linear_problem: LinearProblem, init_f=None) -> jnp.ndarray: + + + cost_matrix = linear_problem.geom.cost_matrix + if self.stop_gradient: + cost_matrix = jax.lax.stop_gradient(cost_matrix) + + n = cost_matrix.shape[0] + f_potential = jnp.zeros(n) if init_f is None else init_f + + if not isinstance(linear_problem.geom, PointCloud): + return f_potential + + else: + x = linear_problem.geom.x + y = linear_problem.geom.y + gaussian_a = Gaussian.from_samples(x, linear_problem.a) + gaussian_b = Gaussian.from_samples(y, linear_problem.b) + + f_potential = gaussian_a.f_potential(dest=gaussian_b, points=x) + + return f_potential + +class SortingInit(SinkhornInitializer): + + def __init__(self, vector_min=False, tol=1e-2, max_iter=100, stop_gradient=True) -> None: + super().__init__() + + self.tolerance = tol + self.stop_gradient = stop_gradient + self.max_iter = self.max_iter + self.update_fn = self.vectorized_update if vector_min else self.coordinate_update + + def vectorized_update(self, f, modified_cost): + f = jnp.min(modified_cost + f[None, :], axis=1) + return f + + + @jax.jit + def coordinate_update(self, f, modified_cost): + + def body_fn(i, f): + new_f = jnp.min(modified_cost[i, :] + f) + f = f.at[i].set(new_f) + return f + + return jax.lax.fori_loop(0, len(f), body_fn, f) + + @functools.partial(jax.jit, static_argnums=(1, 2, 3)) + def init_sorting_dual(self, modified_cost, f_potential): + it = 0 + diff = self.tolerance + 1.0 + + state = (f_potential, diff, it) + def body_fn(state): + prev_f, _, it = state + f_potential = self.update_fn(prev_f, modified_cost) + diff = 
jnp.sum((f_potential - prev_f) ** 2) + it += 1 + return f_potential, diff, it + + def cond_fn(state): + _, diff, it = state + return (diff > self.tolerance) & (it < self.mat_iter) + + f_potential, _, it = jax.lax.while_loop(cond_fun=cond_fn, body_fun=body_fn, init_val=state) + + return f_potential + + def apply(self, linear_problem: LinearProblem, init_f=None) -> jnp.ndarray: + + cost_matrix = linear_problem.geom.cost_matrix + if self.stop_gradient: + cost_matrix = jax.lax.stop_gradient(cost_matrix) + + modified_cost = cost_matrix - jnp.diag(cost_matrix)[None, :] + + n = cost_matrix.shape[0] + f_potential = jnp.zeros(n) if init_f is None else init_f + + f_potential = self.init_sorting_dual(modified_cost, f_potential) + + return f_potential + + + + + + + + + + + diff --git a/ott/tools/gaussian_mixture/gaussian.py b/ott/tools/gaussian_mixture/gaussian.py index 967a6abcb..9f85a4f1b 100644 --- a/ott/tools/gaussian_mixture/gaussian.py +++ b/ott/tools/gaussian_mixture/gaussian.py @@ -23,6 +23,9 @@ LOG2PI = math.log(2. * math.pi) +@jax.vmap +def batch_inner_product(x, y): + return x.dot(y) @jax.tree_util.register_pytree_node_class class Gaussian: @@ -31,6 +34,27 @@ class Gaussian: def __init__(self, loc: jnp.ndarray, scale: scale_tril.ScaleTriL): self._loc = loc self._scale = scale + + @classmethod + def from_samples(cls, x:jnp.ndarray, weights: jnp.ndarray = None) -> 'Gaussian': + """Construct a Gaussian from weighted samples + + Args: + x: [n x d] array of samples + weights: [n] array of weights + + Returns: + Gaussian. + """ + + if weights is None: + n = x.shape[0] + weights = jnp.ones(n)/ n + + mean = weights.dot(x) + scaled_centered_x = (x - mean) * weights.reshape(-1, 1) + cov = (scaled_centered_x).T.dot(scaled_centered_x) / weights.T.dot(weights) + return cls.from_mean_and_cov(mean=mean, cov=cov) @classmethod def from_random( @@ -40,6 +64,7 @@ def from_random( stdev: float = 0.1, dtype: Optional[jnp.dtype] = None ) -> 'Gaussian': + """Construct a random Gaussian. Args: @@ -127,6 +152,17 @@ def w2_dist(self, other: 'Gaussian') -> jnp.ndarray: delta_sigma = self.scale.w2_dist(other.scale) return delta_mean + delta_sigma + def f_potential(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: + scale_matrix = self.scale.transport_scale_matrix(dest_scale=dest.scale) + centered_x = points - self.loc + scaled_x = jnp.transpose(jnp.matmul(scale_matrix, jnp.transpose(centered_x))) + return ( + 0.5 * batch_inner_product(points, points) + - 0.5 * batch_inner_product(centered_x, scaled_x) + - (points).dot(dest.loc) + ) + + def transport(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: return self.scale.transport( dest_scale=dest.scale, points=points - self.loc[None] diff --git a/ott/tools/gaussian_mixture/scale_tril.py b/ott/tools/gaussian_mixture/scale_tril.py index 3fd5ab632..9fcb652a6 100644 --- a/ott/tools/gaussian_mixture/scale_tril.py +++ b/ott/tools/gaussian_mixture/scale_tril.py @@ -157,17 +157,17 @@ def _flatten_cov(cov: jnp.ndarray) -> jnp.ndarray: return (cost_fn.norm(x0) + cost_fn.norm(x1) + cost_fn.pairwise(x0, x1))[...,] - def transport( - self, dest_scale: 'ScaleTriL', points: jnp.ndarray - ) -> jnp.ndarray: - """Transport between 0-mean normal w/ current scale to one w/ dest_scale. 
+ def transport_scale_matrix(self, dest_scale: 'ScaleTriL') -> jnp.ndarray: + """ + Scaling matrix used in transport between 0-mean normal, \mu, w/ current scale to one w/ dest_scale, \nu + + m = \Sigma_\mu ^{-1/2} [ \Sigma_\mu ^{1/2} \Sigma_\nu \Sigma_\mu ^{1/2}] ^{1/2}\Sigma_\mu ^{-1/2} Args: dest_scale: destination Scale - points: points to transport Returns: - Points transported to a Gaussian with the new scale. + Gaussian scaling matrix, same dimension as self.covaraince() """ sqrt0, sqrt0_inv = linalg.matrix_powers(self.covariance(), (0.5, -0.5)) sigma1 = dest_scale.covariance() @@ -175,6 +175,22 @@ def transport( jnp.matmul(sqrt0, jnp.matmul(sigma1, sqrt0)) ) m = jnp.matmul(sqrt0_inv, jnp.matmul(m, sqrt0_inv)) + return m + + def transport( + self, dest_scale: 'ScaleTriL', points: jnp.ndarray + ) -> jnp.ndarray: + """Transport between 0-mean normal w/ current scale to one w/ dest_scale. + + Args: + dest_scale: destination Scale + points: points to transport + + Returns: + Points transported to a Gaussian with the new scale. + """ + + m = self.transport_scale_matrix(dest_scale) return jnp.transpose(jnp.matmul(m, jnp.transpose(points))) def tree_flatten(self): From 354951770173ffaaf588ee73bb82afe6649b2ec3 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Fri, 1 Jul 2022 11:50:48 +0200 Subject: [PATCH 02/46] add initialization logic to sinkhorn --- ott/core/__init__.py | 2 + ott/core/initializers.py | 132 ++++++++++++++++++++++++++++++++------- ott/core/problems.py | 1 + ott/core/sinkhorn.py | 24 ++++--- 4 files changed, 128 insertions(+), 31 deletions(-) diff --git a/ott/core/__init__.py b/ott/core/__init__.py index cb962aa62..d671c3529 100644 --- a/ott/core/__init__.py +++ b/ott/core/__init__.py @@ -27,6 +27,7 @@ quad_problems, sinkhorn, sinkhorn_lr, + initializers ) # from . 
import neuraldual @@ -34,5 +35,6 @@ from .linear_problems import LinearProblem from .sinkhorn import Sinkhorn from .sinkhorn_lr import LRSinkhorn +from .initializers import SinkhornInitializer # pytype: enable=import-error # kwargs-checking diff --git a/ott/core/initializers.py b/ott/core/initializers.py index e168f0795..6f27e7d75 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -16,37 +16,84 @@ import functools import jax from jax import numpy as jnp +from typing import Optional + +from ott.core.linear_problems import LinearProblem +from ott.tools.gaussian_mixture.gaussian import Gaussian +from ott.geometry.pointcloud import PointCloud +from ott.core.problems import OTProblem -from .linear_problems import LinearProblem -from ..tools.gaussian_mixture.gaussian import Gaussian -from ..geometry.pointcloud import PointCloud @jax.tree_util.register_pytree_node_class class SinkhornInitializer(): - def apply(self, linear_problem: LinearProblem) -> jnp.ndarray: + def init_dual_a(self, ot_problem: OTProblem, lse_mode: bool = True) -> jnp.ndarray: """ Input: - linear_problem: OT problem between discrete distributions of size n and m + ot_problem: OT problem between discrete distributions of size n and m Return: dual potential, array of size m """ - pass + a = ot_problem.a + init_dual_a = jnp.zeros_like(a) if lse_mode else jnp.ones_like(a) + return init_dual_a + + def init_dual_b(self, ot_problem: OTProblem, lse_mode: bool = True) -> jnp.ndarray: + """ + Input: + ot_problem: OT problem between discrete distributions of size n and m + + Return: + dual potential, array of size m + """ + b = ot_problem.b + init_dual_b = jnp.zeros_like(b) if lse_mode else jnp.ones_like(b) + return init_dual_b + + def remove_null_weight_potentials(self, ot_problem, init_dual_a, init_dual_b, lse_mode: bool=True): + # Cancel dual variables for zero weights. + a, b = ot_problem.a, ot_problem.b + init_dual_a = jnp.where( + a > 0, init_dual_a, -jnp.inf if lse_mode else 0.0 + ) + init_dual_b = jnp.where( + b > 0, init_dual_b, -jnp.inf if lse_mode else 0.0 + ) + return init_dual_a, init_dual_b + + def default_dual_a(self, ot_problem, lse_mode): + a = ot_problem.a + init_dual_a = jnp.zeros_like(a) if lse_mode else jnp.ones_like(a) + return init_dual_a + + def default_dual_b(self, ot_problem, lse_mode): + b = ot_problem.b + init_dual_b = jnp.zeros_like(b) if lse_mode else jnp.ones_like(b) + return init_dual_b class GaussianInitializer(SinkhornInitializer): - def __init__(self, stop_gradient=True) -> None: + def __init__(self, stop_gradient: Optional[bool] =True) -> None: + """_summary_ + + Args: + stop_gradient (bool, optional): _description_. Defaults to True. 
+ """ super().__init__() self.stop_gradient = stop_gradient - def apply(self, linear_problem: LinearProblem, init_f=None) -> jnp.ndarray: - + def init_dual_a(self, linear_problem: LinearProblem, init_f: Optional[jnp.ndarray] =None, lse_mode: bool = True) -> jnp.ndarray: + """_summary_ + + Returns: + _type_: _description_ + """ cost_matrix = linear_problem.geom.cost_matrix if self.stop_gradient: cost_matrix = jax.lax.stop_gradient(cost_matrix) @@ -60,30 +107,56 @@ def apply(self, linear_problem: LinearProblem, init_f=None) -> jnp.ndarray: else: x = linear_problem.geom.x y = linear_problem.geom.y - gaussian_a = Gaussian.from_samples(x, linear_problem.a) - gaussian_b = Gaussian.from_samples(y, linear_problem.b) - - f_potential = gaussian_a.f_potential(dest=gaussian_b, points=x) + gaussian_a = Gaussian.from_samples(x, weights=linear_problem.a) + gaussian_b = Gaussian.from_samples(y, weights=linear_problem.b) + f_potential = gaussian_a.f_potential(dest=gaussian_b, points=x) return f_potential class SortingInit(SinkhornInitializer): - def __init__(self, vector_min=False, tol=1e-2, max_iter=100, stop_gradient=True) -> None: + def __init__(self, + vector_min: Optional[bool] = False, + tol: Optional[float] = 1e-2, + max_iter: Optional[int] = 100, + stop_gradient: Optional[bool] = True) -> None: + """_summary_ + + Args: + vector_min (Optional[bool], optional): _description_. Defaults to False. + tol (Optional[float], optional): _description_. Defaults to 1e-2. + max_iter (Optional[int], optional): _description_. Defaults to 100. + stop_gradient (Optional[bool], optional): _description_. Defaults to True. + """ super().__init__() self.tolerance = tol self.stop_gradient = stop_gradient - self.max_iter = self.max_iter + self.max_iter = max_iter self.update_fn = self.vectorized_update if vector_min else self.coordinate_update - def vectorized_update(self, f, modified_cost): + def vectorized_update(self, f: jnp.ndarray, modified_cost: jnp.ndarray): + """_summary_ + + Args: + f (jnp.ndarray): _description_ + modified_cost (jnp.ndarray): _description_ + + Returns: + _type_: _description_ + """ f = jnp.min(modified_cost + f[None, :], axis=1) return f @jax.jit - def coordinate_update(self, f, modified_cost): + def coordinate_update(self, f: jnp.ndarray, modified_cost: jnp.ndarray): + """_summary_ + + Args: + f (jnp.ndarray): _description_ + modified_cost (jnp.ndarray): _description_ + """ def body_fn(i, f): new_f = jnp.min(modified_cost[i, :] + f) @@ -93,7 +166,16 @@ def body_fn(i, f): return jax.lax.fori_loop(0, len(f), body_fn, f) @functools.partial(jax.jit, static_argnums=(1, 2, 3)) - def init_sorting_dual(self, modified_cost, f_potential): + def init_sorting_dual(self, modified_cost: jnp.ndarray, f_potential: jnp.ndarray): + """_summary_ + + Args: + modified_cost (jnp.ndarray): _description_ + f_potential (jnp.ndarray): _description_ + + Returns: + _type_: _description_ + """ it = 0 diff = self.tolerance + 1.0 @@ -110,11 +192,19 @@ def cond_fn(state): return (diff > self.tolerance) & (it < self.mat_iter) f_potential, _, it = jax.lax.while_loop(cond_fun=cond_fn, body_fun=body_fn, init_val=state) - + return f_potential - def apply(self, linear_problem: LinearProblem, init_f=None) -> jnp.ndarray: - + def init_dual_a(self, linear_problem: LinearProblem, init_f: jnp.ndarray = None, lse_mode: bool = True) -> jnp.ndarray: + """ + + Args: + linear_problem (LinearProblem): _description_ + init_f (jnp.ndarray, optional): _description_. Defaults to None. 
+ + Returns: + jnp.ndarray: _description_ + """ cost_matrix = linear_problem.geom.cost_matrix if self.stop_gradient: cost_matrix = jax.lax.stop_gradient(cost_matrix) diff --git a/ott/core/problems.py b/ott/core/problems.py index e60b4deb1..d8ae6e148 100644 --- a/ott/core/problems.py +++ b/ott/core/problems.py @@ -20,6 +20,7 @@ from ott.core import linear_problems, quad_problems from ott.geometry import geometry, pointcloud +OTProblem = Union[linear_problems.LinearProblem, quad_problems.QuadraticProblem] def make( *args: Union[jnp.ndarray, geometry.Geometry, linear_problems.LinearProblem, diff --git a/ott/core/sinkhorn.py b/ott/core/sinkhorn.py index 1fe0c99d7..f0de26062 100644 --- a/ott/core/sinkhorn.py +++ b/ott/core/sinkhorn.py @@ -14,6 +14,7 @@ # Lint as: python3 """A Jax implementation of the Sinkhorn algorithm.""" +from pickle import NONE from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple import jax @@ -26,6 +27,7 @@ from ott.core import linear_problems from ott.core import momentum as momentum_lib from ott.core import unbalanced_functions +from ott.core import initializers as init_lib from ott.geometry import geometry @@ -349,6 +351,7 @@ def __init__( use_danskin: Optional[bool] = None, implicit_diff: Optional[implicit_lib.ImplicitDiff ] = implicit_lib.ImplicitDiff(), # noqa: E124 + potential_initializer: Optional[init_lib.SinkhornInitializer] = init_lib.SinkhornInitializer(), jit: bool = True ): self.lse_mode = lse_mode @@ -368,6 +371,7 @@ def __init__( self.anderson = anderson self.implicit_diff = implicit_diff self.parallel_dual_updates = parallel_dual_updates + self.potential_initializer = potential_initializer self.jit = jit # Force implicit_differentiation to True when using Anderson acceleration, @@ -400,20 +404,20 @@ def __call__( init: Optional[Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]]] = None ) -> SinkhornOutput: """Main interface to run sinkhorn.""" # noqa: D401 + + # initialization init_dual_a, init_dual_b = (init if init is not None else (None, None)) - a, b = ot_prob.a, ot_prob.b + if init_dual_a is None: - init_dual_a = jnp.zeros_like(a) if self.lse_mode else jnp.ones_like(a) + init_dual_a = self.potential_initializer.init_dual_a(ot_problem=ot_prob, lse_mode=self.lse_mode) + if init_dual_b is None: - init_dual_b = jnp.zeros_like(b) if self.lse_mode else jnp.ones_like(b) - # Cancel dual variables for zero weights. - init_dual_a = jnp.where( - a > 0, init_dual_a, -jnp.inf if self.lse_mode else 0.0 - ) - init_dual_b = jnp.where( - b > 0, init_dual_b, -jnp.inf if self.lse_mode else 0.0 - ) + init_dual_b = self.potential_initializer.init_dual_b(ot_problem=ot_prob, lse_mode=self.lse_mode) + # Cancel dual variables for zero weights. 
+ init_dual_a, init_dual_b = self.potential_initializer.remove_null_weight_potentials(ot_problem=ot_prob, + init_dual_a=init_dual_a, + init_dual_b=init_dual_b) run_fn = jax.jit(run) if self.jit else run return run_fn(ot_prob, self, (init_dual_a, init_dual_b)) From d8cdfd3921e235476d668c2e9d7a2b1b0148aa33 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Fri, 1 Jul 2022 12:18:37 +0200 Subject: [PATCH 03/46] remove general ot problem type --- ott/core/initializers.py | 6 +++--- ott/core/problems.py | 2 -- ott/tools/gaussian_mixture/gaussian.py | 18 ++++++++++++++++++ 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 6f27e7d75..d60e6e909 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -21,13 +21,13 @@ from ott.core.linear_problems import LinearProblem from ott.tools.gaussian_mixture.gaussian import Gaussian from ott.geometry.pointcloud import PointCloud -from ott.core.problems import OTProblem + @jax.tree_util.register_pytree_node_class class SinkhornInitializer(): - def init_dual_a(self, ot_problem: OTProblem, lse_mode: bool = True) -> jnp.ndarray: + def init_dual_a(self, ot_problem: LinearProblem, lse_mode: bool = True) -> jnp.ndarray: """ Input: ot_problem: OT problem between discrete distributions of size n and m @@ -40,7 +40,7 @@ def init_dual_a(self, ot_problem: OTProblem, lse_mode: bool = True) -> jnp.ndarr return init_dual_a - def init_dual_b(self, ot_problem: OTProblem, lse_mode: bool = True) -> jnp.ndarray: + def init_dual_b(self, ot_problem: LinearProblem, lse_mode: bool = True) -> jnp.ndarray: """ Input: ot_problem: OT problem between discrete distributions of size n and m diff --git a/ott/core/problems.py b/ott/core/problems.py index d8ae6e148..ce9ee328b 100644 --- a/ott/core/problems.py +++ b/ott/core/problems.py @@ -20,8 +20,6 @@ from ott.core import linear_problems, quad_problems from ott.geometry import geometry, pointcloud -OTProblem = Union[linear_problems.LinearProblem, quad_problems.QuadraticProblem] - def make( *args: Union[jnp.ndarray, geometry.Geometry, linear_problems.LinearProblem, quad_problems.QuadraticProblem], diff --git a/ott/tools/gaussian_mixture/gaussian.py b/ott/tools/gaussian_mixture/gaussian.py index 9f85a4f1b..ed69bfdb3 100644 --- a/ott/tools/gaussian_mixture/gaussian.py +++ b/ott/tools/gaussian_mixture/gaussian.py @@ -153,6 +153,15 @@ def w2_dist(self, other: 'Gaussian') -> jnp.ndarray: return delta_mean + delta_sigma def f_potential(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: + """_summary_ + + Args: + dest (Gaussian): _description_ + points (jnp.ndarray): _description_ + + Returns: + jnp.ndarray: _description_ + """ scale_matrix = self.scale.transport_scale_matrix(dest_scale=dest.scale) centered_x = points - self.loc scaled_x = jnp.transpose(jnp.matmul(scale_matrix, jnp.transpose(centered_x))) @@ -164,6 +173,15 @@ def f_potential(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: def transport(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: + """_summary_ + + Args: + dest (Gaussian): _description_ + points (jnp.ndarray): _description_ + + Returns: + jnp.ndarray: _description_ + """ return self.scale.transport( dest_scale=dest.scale, points=points - self.loc[None] ) + dest.loc[None] From 9050efb36015a675dc0aab99aafedc27e77023ff Mon Sep 17 00:00:00 2001 From: James Thornton Date: Fri, 1 Jul 2022 12:24:23 +0200 Subject: [PATCH 04/46] remove import tools.gaussian from top level --- ott/core/initializers.py | 10 ++++++---- 
ott/core/problems.py | 2 ++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index d60e6e909..551edbd69 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -19,15 +19,14 @@ from typing import Optional from ott.core.linear_problems import LinearProblem -from ott.tools.gaussian_mixture.gaussian import Gaussian from ott.geometry.pointcloud import PointCloud - +from ott.core.problems import OTProblem @jax.tree_util.register_pytree_node_class class SinkhornInitializer(): - def init_dual_a(self, ot_problem: LinearProblem, lse_mode: bool = True) -> jnp.ndarray: + def init_dual_a(self, ot_problem: OTProblem, lse_mode: bool = True) -> jnp.ndarray: """ Input: ot_problem: OT problem between discrete distributions of size n and m @@ -40,7 +39,7 @@ def init_dual_a(self, ot_problem: LinearProblem, lse_mode: bool = True) -> jnp.n return init_dual_a - def init_dual_b(self, ot_problem: LinearProblem, lse_mode: bool = True) -> jnp.ndarray: + def init_dual_b(self, ot_problem: OTProblem, lse_mode: bool = True) -> jnp.ndarray: """ Input: ot_problem: OT problem between discrete distributions of size n and m @@ -94,6 +93,8 @@ def init_dual_a(self, linear_problem: LinearProblem, init_f: Optional[jnp.ndarra Returns: _type_: _description_ """ + from ott.tools.gaussian_mixture.gaussian import Gaussian + cost_matrix = linear_problem.geom.cost_matrix if self.stop_gradient: cost_matrix = jax.lax.stop_gradient(cost_matrix) @@ -216,6 +217,7 @@ def init_dual_a(self, linear_problem: LinearProblem, init_f: jnp.ndarray = None, f_potential = self.init_sorting_dual(modified_cost, f_potential) + return f_potential diff --git a/ott/core/problems.py b/ott/core/problems.py index ce9ee328b..d8ae6e148 100644 --- a/ott/core/problems.py +++ b/ott/core/problems.py @@ -20,6 +20,8 @@ from ott.core import linear_problems, quad_problems from ott.geometry import geometry, pointcloud +OTProblem = Union[linear_problems.LinearProblem, quad_problems.QuadraticProblem] + def make( *args: Union[jnp.ndarray, geometry.Geometry, linear_problems.LinearProblem, quad_problems.QuadraticProblem], From a680d6b3541db5829bf16635dc590f79127f939d Mon Sep 17 00:00:00 2001 From: James Thornton Date: Fri, 1 Jul 2022 12:26:05 +0200 Subject: [PATCH 05/46] remove problems from top level --- ott/core/initializers.py | 8 ++++---- ott/core/problems.py | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 551edbd69..8f1fc710b 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -20,13 +20,13 @@ from ott.core.linear_problems import LinearProblem from ott.geometry.pointcloud import PointCloud -from ott.core.problems import OTProblem + @jax.tree_util.register_pytree_node_class class SinkhornInitializer(): - def init_dual_a(self, ot_problem: OTProblem, lse_mode: bool = True) -> jnp.ndarray: + def init_dual_a(self, ot_problem: LinearProblem, lse_mode: bool = True) -> jnp.ndarray: """ Input: ot_problem: OT problem between discrete distributions of size n and m @@ -39,7 +39,7 @@ def init_dual_a(self, ot_problem: OTProblem, lse_mode: bool = True) -> jnp.ndarr return init_dual_a - def init_dual_b(self, ot_problem: OTProblem, lse_mode: bool = True) -> jnp.ndarray: + def init_dual_b(self, ot_problem: LinearProblem, lse_mode: bool = True) -> jnp.ndarray: """ Input: ot_problem: OT problem between discrete distributions of size n and m @@ -94,7 +94,7 @@ def init_dual_a(self, linear_problem: LinearProblem, 
init_f: Optional[jnp.ndarra _type_: _description_ """ from ott.tools.gaussian_mixture.gaussian import Gaussian - + cost_matrix = linear_problem.geom.cost_matrix if self.stop_gradient: cost_matrix = jax.lax.stop_gradient(cost_matrix) diff --git a/ott/core/problems.py b/ott/core/problems.py index d8ae6e148..ce9ee328b 100644 --- a/ott/core/problems.py +++ b/ott/core/problems.py @@ -20,8 +20,6 @@ from ott.core import linear_problems, quad_problems from ott.geometry import geometry, pointcloud -OTProblem = Union[linear_problems.LinearProblem, quad_problems.QuadraticProblem] - def make( *args: Union[jnp.ndarray, geometry.Geometry, linear_problems.LinearProblem, quad_problems.QuadraticProblem], From 378777e784f45f81f19cf62e4ad5f51c98edb292 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Fri, 1 Jul 2022 12:28:40 +0200 Subject: [PATCH 06/46] do not register initializer as pytree --- ott/core/initializers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 8f1fc710b..8aba93423 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -23,7 +23,6 @@ -@jax.tree_util.register_pytree_node_class class SinkhornInitializer(): def init_dual_a(self, ot_problem: LinearProblem, lse_mode: bool = True) -> jnp.ndarray: From 7fa567a3c3aa2db7fa1af407f54a9099e17d4247 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Fri, 1 Jul 2022 12:33:55 +0200 Subject: [PATCH 07/46] add initializer to make --- ott/core/sinkhorn.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ott/core/sinkhorn.py b/ott/core/sinkhorn.py index f0de26062..b0c78e2e9 100644 --- a/ott/core/sinkhorn.py +++ b/ott/core/sinkhorn.py @@ -410,7 +410,7 @@ def __call__( if init_dual_a is None: init_dual_a = self.potential_initializer.init_dual_a(ot_problem=ot_prob, lse_mode=self.lse_mode) - + if init_dual_b is None: init_dual_b = self.potential_initializer.init_dual_b(ot_problem=ot_prob, lse_mode=self.lse_mode) @@ -695,6 +695,7 @@ def make( precondition_fun: Optional[Callable[[float], float]] = None, parallel_dual_updates: bool = False, use_danskin: bool = None, + potential_initializer: Optional[init_lib.SinkhornInitializer] = init_lib.SinkhornInitializer(), jit: bool = False ) -> Sinkhorn: """For backward compatibility.""" @@ -729,6 +730,7 @@ def make( implicit_diff=implicit_diff, parallel_dual_updates=parallel_dual_updates, use_danskin=use_danskin, + potential_initializer=potential_initializer, jit=jit ) From 4ce8357c498672f5150b2add7a78636e51696ef6 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Fri, 1 Jul 2022 12:37:53 +0200 Subject: [PATCH 08/46] rename init arg to ot_problem --- ott/core/initializers.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 8aba93423..07949d1ee 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -18,7 +18,7 @@ from jax import numpy as jnp from typing import Optional -from ott.core.linear_problems import LinearProblem +from ott.core.ot_problems import LinearProblem from ott.geometry.pointcloud import PointCloud @@ -85,7 +85,7 @@ def __init__(self, stop_gradient: Optional[bool] =True) -> None: self.stop_gradient = stop_gradient - def init_dual_a(self, linear_problem: LinearProblem, init_f: Optional[jnp.ndarray] =None, lse_mode: bool = True) -> jnp.ndarray: + def init_dual_a(self, ot_problem: LinearProblem, init_f: Optional[jnp.ndarray] =None, lse_mode: bool = True) -> jnp.ndarray: """_summary_ @@ -94,21 
+94,21 @@ def init_dual_a(self, linear_problem: LinearProblem, init_f: Optional[jnp.ndarra """ from ott.tools.gaussian_mixture.gaussian import Gaussian - cost_matrix = linear_problem.geom.cost_matrix + cost_matrix = ot_problem.geom.cost_matrix if self.stop_gradient: cost_matrix = jax.lax.stop_gradient(cost_matrix) n = cost_matrix.shape[0] f_potential = jnp.zeros(n) if init_f is None else init_f - if not isinstance(linear_problem.geom, PointCloud): + if not isinstance(ot_problem.geom, PointCloud): return f_potential else: - x = linear_problem.geom.x - y = linear_problem.geom.y - gaussian_a = Gaussian.from_samples(x, weights=linear_problem.a) - gaussian_b = Gaussian.from_samples(y, weights=linear_problem.b) + x = ot_problem.geom.x + y = ot_problem.geom.y + gaussian_a = Gaussian.from_samples(x, weights=ot_problem.a) + gaussian_b = Gaussian.from_samples(y, weights=ot_problem.b) f_potential = gaussian_a.f_potential(dest=gaussian_b, points=x) return f_potential @@ -195,17 +195,17 @@ def cond_fn(state): return f_potential - def init_dual_a(self, linear_problem: LinearProblem, init_f: jnp.ndarray = None, lse_mode: bool = True) -> jnp.ndarray: + def init_dual_a(self, ot_problem: LinearProblem, init_f: jnp.ndarray = None, lse_mode: bool = True) -> jnp.ndarray: """ Args: - linear_problem (LinearProblem): _description_ + ot_problem (LinearProblem): _description_ init_f (jnp.ndarray, optional): _description_. Defaults to None. Returns: jnp.ndarray: _description_ """ - cost_matrix = linear_problem.geom.cost_matrix + cost_matrix = ot_problem.geom.cost_matrix if self.stop_gradient: cost_matrix = jax.lax.stop_gradient(cost_matrix) From 2856b4b1bc88a750b27bf9a3c271b4f6b9bb80cd Mon Sep 17 00:00:00 2001 From: James Thornton Date: Fri, 1 Jul 2022 12:38:55 +0200 Subject: [PATCH 09/46] rename init arg to ot_problem --- ott/core/initializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 07949d1ee..9422697cd 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -18,7 +18,7 @@ from jax import numpy as jnp from typing import Optional -from ott.core.ot_problems import LinearProblem +from ott.core.linear_problems import LinearProblem from ott.geometry.pointcloud import PointCloud From 42eb327b6f5f9442c8a96d33371bd7ade22dd869 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Fri, 1 Jul 2022 14:44:05 +0200 Subject: [PATCH 10/46] scale gaus init by 2 --- ott/core/initializers.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 9422697cd..cc4e06c22 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -102,16 +102,15 @@ def init_dual_a(self, ot_problem: LinearProblem, init_f: Optional[jnp.ndarray] = f_potential = jnp.zeros(n) if init_f is None else init_f if not isinstance(ot_problem.geom, PointCloud): + # warning that init not applied return f_potential - else: x = ot_problem.geom.x y = ot_problem.geom.y gaussian_a = Gaussian.from_samples(x, weights=ot_problem.a) gaussian_b = Gaussian.from_samples(y, weights=ot_problem.b) - f_potential = gaussian_a.f_potential(dest=gaussian_b, points=x) - - return f_potential + f_potential = 2*gaussian_a.f_potential(dest=gaussian_b, points=x) + return f_potential class SortingInit(SinkhornInitializer): @@ -148,8 +147,6 @@ def vectorized_update(self, f: jnp.ndarray, modified_cost: jnp.ndarray): f = jnp.min(modified_cost + f[None, :], axis=1) return f - - @jax.jit def 
coordinate_update(self, f: jnp.ndarray, modified_cost: jnp.ndarray): """_summary_ @@ -165,7 +162,6 @@ def body_fn(i, f): return jax.lax.fori_loop(0, len(f), body_fn, f) - @functools.partial(jax.jit, static_argnums=(1, 2, 3)) def init_sorting_dual(self, modified_cost: jnp.ndarray, f_potential: jnp.ndarray): """_summary_ @@ -189,7 +185,7 @@ def body_fn(state): def cond_fn(state): _, diff, it = state - return (diff > self.tolerance) & (it < self.mat_iter) + return (diff > self.tolerance) & (it < self.matxiter) f_potential, _, it = jax.lax.while_loop(cond_fun=cond_fn, body_fun=body_fn, init_val=state) From 2de12ccaae1d6a0855f2b37383bdb6a2eaac457a Mon Sep 17 00:00:00 2001 From: James Thornton Date: Fri, 1 Jul 2022 14:45:17 +0200 Subject: [PATCH 11/46] typo --- ott/core/initializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index cc4e06c22..cb471a880 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -185,7 +185,7 @@ def body_fn(state): def cond_fn(state): _, diff, it = state - return (diff > self.tolerance) & (it < self.matxiter) + return (diff > self.tolerance) & (it < self.max_iter) f_potential, _, it = jax.lax.while_loop(cond_fun=cond_fn, body_fun=body_fn, init_val=state) From 4d805089e594a3141f256b81a439c0acb450a121 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Fri, 1 Jul 2022 15:19:35 +0200 Subject: [PATCH 12/46] add basic speed tests --- tests/core/initializers_test.py | 153 ++++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 tests/core/initializers_test.py diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py new file mode 100644 index 000000000..3bb7a5a99 --- /dev/null +++ b/tests/core/initializers_test.py @@ -0,0 +1,153 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+# Lint as: python3
+"""Tests for Sinkhorn initializers."""
+
+import jax
+import jax.numpy as jnp
+import jax.test_util
+import numpy as np
+from absl.testing import absltest, parameterized
+
+from ott.core.sinkhorn import sinkhorn
+from ott.geometry.pointcloud import PointCloud
+from ott.core import initializers as init_lib
+
+
+class InitializerTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.rng = jax.random.PRNGKey(0)
+
+  def test_sorting_init(self):
+    """Tests sorting dual initializer."""
+
+    # init initializer
+    sort_init = init_lib.SortingInit(vector_min=True)
+
+    # define sinkhorn functions
+    @jax.jit
+    def run_sinkhorn_sort_init(x, y, a=None, b=None, init_dual_a=None):
+      sink_kwargs = {'jit': True,
+                     'threshold': 0.001,
+                     'max_iterations': 10**5,
+                     'potential_initializer': sort_init}
+      geom_kwargs = {'epsilon': 0.01}
+      geom = PointCloud(x, y, **geom_kwargs)
+      out = sinkhorn(geom, a=a, b=b, init_dual_a=init_dual_a, **sink_kwargs)
+      return out
+
+    @jax.jit
+    def run_sinkhorn(x, y, a=None, b=None, init_dual_a=None):
+      sink_kwargs = {'jit': True, 'threshold': 0.001, 'max_iterations': 10**5}
+      geom_kwargs = {'epsilon': 0.01}
+      geom = PointCloud(x, y, **geom_kwargs)
+      out = sinkhorn(geom, a=a, b=b, init_dual_a=init_dual_a, **sink_kwargs)
+      return out
+
+    # define ot problem
+    x_init = np.array([-1., 0, .22])
+    y_init = np.array([0., 0, 1.1])
+
+    buf = 500
+    np.random.seed(0)
+    x = np.concatenate([x_init, 10 + np.abs(np.random.normal(size=buf))]) * 5
+    y = np.concatenate([y_init, 10 + np.abs(np.random.normal(size=buf))]) * 5
+
+    x = np.sort(x)
+    y = np.sort(y)
+
+    n = len(x)
+    m = len(y)
+    a = np.ones(n) / n
+    b = np.ones(m) / m
+
+    x_jnp, y_jnp = jnp.array(x.reshape(-1, 1)), jnp.array(y.reshape(-1, 1))
+
+    # run sinkhorn
+    sink_out = run_sinkhorn(x=x_jnp, y=y_jnp, a=a, b=b)
+    base_num_iter = jnp.sum(sink_out.errors > -1)
+
+    sink_out = run_sinkhorn_sort_init(x=x_jnp, y=y_jnp, a=a, b=b)
+    sort_num_iter = jnp.sum(sink_out.errors > -1)
+
+    # check initializer is better
+    self.assertTrue(base_num_iter >= sort_num_iter)
+
+  def test_gaus_initializer(self):
+    """Tests Gaussian initializer."""
+
+    # init initializer
+    gaus_init = init_lib.GaussianInitializer()
+
+    @jax.jit
+    def run_sinkhorn(x, y, a=None, b=None, init_dual_a=None):
+      sink_kwargs = {'jit': True, 'threshold': 0.001, 'max_iterations': 10**5}
+      geom_kwargs = {'epsilon': 0.01}
+      geom = PointCloud(x, y, **geom_kwargs)
+      out = sinkhorn(geom, a=a, b=b, init_dual_a=init_dual_a, **sink_kwargs)
+      return out
+
+    @jax.jit
+    def run_sinkhorn_gaus_init(x, y, a=None, b=None, init_dual_a=None):
+      sink_kwargs = {'jit': True,
+                     'threshold': 0.001,
+                     'max_iterations': 10**5,
+                     'potential_initializer': gaus_init}
+      geom_kwargs = {'epsilon': 0.01}
+      geom = PointCloud(x, y, **geom_kwargs)
+      out = sinkhorn(geom, a=a, b=b, init_dual_a=init_dual_a, **sink_kwargs)
+      return out
+
+    # define ot problem
+    np.random.seed(0)
+    n, d = 1000, 2
+    mu_a = np.array([-1, 1]) * 5
+    mu_b = np.array([0, 0])
+
+    x = np.random.normal(size=n * d).reshape(n, d) + mu_a
+    y = np.random.normal(size=n * d).reshape(n, d) + mu_b
+
+    n = len(x)
+    m = len(y)
+    a = np.ones(n) / n
+    b = np.ones(m) / m
+
+    x_jnp, y_jnp = jnp.array(x), jnp.array(y)
+
+    # run sinkhorn
+    sink_out = run_sinkhorn(x=x_jnp, y=y_jnp, a=a, b=b)
+    base_num_iter = jnp.sum(sink_out.errors > -1)
+
+    sink_out = run_sinkhorn_gaus_init(x=x_jnp, y=y_jnp, a=a, b=b)
+    gaus_num_iter = jnp.sum(sink_out.errors > -1)
+
+    # check initializer is better
+    self.assertTrue(base_num_iter >= 
gaus_num_iter) + + +if __name__ == '__main__': + absltest.main() \ No newline at end of file From 23a03e9dc1401b9f95a043bf28c366b99b3f9498 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Fri, 1 Jul 2022 16:19:14 +0200 Subject: [PATCH 13/46] add init to transport tools wrapper, tidy docstring --- ott/core/initializers.py | 11 ++++++----- ott/tools/gaussian_mixture/scale_tril.py | 4 ++-- ott/tools/transport.py | 9 +++++++-- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index cb471a880..436f1d54f 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -92,24 +92,24 @@ def init_dual_a(self, ot_problem: LinearProblem, init_f: Optional[jnp.ndarray] = Returns: _type_: _description_ """ + # import here due to circular imports from ott.tools.gaussian_mixture.gaussian import Gaussian cost_matrix = ot_problem.geom.cost_matrix if self.stop_gradient: cost_matrix = jax.lax.stop_gradient(cost_matrix) - n = cost_matrix.shape[0] - f_potential = jnp.zeros(n) if init_f is None else init_f - if not isinstance(ot_problem.geom, PointCloud): # warning that init not applied - return f_potential + return self.default_dual_a(ot_problem, lse_mode) else: x = ot_problem.geom.x y = ot_problem.geom.y gaussian_a = Gaussian.from_samples(x, weights=ot_problem.a) gaussian_b = Gaussian.from_samples(y, weights=ot_problem.b) - f_potential = 2*gaussian_a.f_potential(dest=gaussian_b, points=x) + # Brenier potential for ground cost ||x-y||^2/2, so multiple by two for cost ||x-y||^2 + f_potential = 2*gaussian_a.f_potential(dest=gaussian_b, points=x) + f_potential = f_potential if lse_mode else jnp.exp(f_potential) return f_potential class SortingInit(SinkhornInitializer): @@ -212,6 +212,7 @@ def init_dual_a(self, ot_problem: LinearProblem, init_f: jnp.ndarray = None, lse f_potential = self.init_sorting_dual(modified_cost, f_potential) + f_potential = f_potential if lse_mode else jnp.exp(f_potential) return f_potential diff --git a/ott/tools/gaussian_mixture/scale_tril.py b/ott/tools/gaussian_mixture/scale_tril.py index 9fcb652a6..35cfa6af2 100644 --- a/ott/tools/gaussian_mixture/scale_tril.py +++ b/ott/tools/gaussian_mixture/scale_tril.py @@ -159,9 +159,9 @@ def _flatten_cov(cov: jnp.ndarray) -> jnp.ndarray: def transport_scale_matrix(self, dest_scale: 'ScaleTriL') -> jnp.ndarray: """ - Scaling matrix used in transport between 0-mean normal, \mu, w/ current scale to one w/ dest_scale, \nu + Scaling matrix used in transport between 0-mean normal, mu, w/ current scale to one w/ dest_scale, nu - m = \Sigma_\mu ^{-1/2} [ \Sigma_\mu ^{1/2} \Sigma_\nu \Sigma_\mu ^{1/2}] ^{1/2}\Sigma_\mu ^{-1/2} + m = Sigma_mu ^{-1/2} [ Sigma_mu ^{1/2} Sigma_nu Sigma_mu ^{1/2}] ^{1/2}Sigma_mu ^{-1/2} Args: dest_scale: destination Scale diff --git a/ott/tools/transport.py b/ott/tools/transport.py index da66df895..bc694202f 100644 --- a/ott/tools/transport.py +++ b/ott/tools/transport.py @@ -121,9 +121,14 @@ def solve( linear = isinstance(pb, linear_problems.LinearProblem) solver_fn = sinkhorn.make if linear else gromov_wasserstein.make geom_keys = ['cost_fn', 'power', 'online'] - remove_keys = geom_keys + eps_keys if linear else geom_keys + + init_dual_a = kwargs.get('init_dual_a', None) + init_dual_b = kwargs.get('init_dual_b', None) + init_keys = ['init_dual_a', 'init_dual_b'] + + remove_keys = init_keys + geom_keys + eps_keys if linear else geom_keys for key in remove_keys: kwargs.pop(key, None) solver = solver_fn(**kwargs) - output = solver(pb) + output = 
solver(pb, (init_dual_a, init_dual_b)) return Transport(pb, output) From 512150f64ff246aabc1490dbe30909f69087873a Mon Sep 17 00:00:00 2001 From: James Thornton Date: Fri, 1 Jul 2022 16:26:14 +0200 Subject: [PATCH 14/46] ceneter potentials in initializers --- ott/core/initializers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 436f1d54f..5e7597ced 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -109,6 +109,7 @@ def init_dual_a(self, ot_problem: LinearProblem, init_f: Optional[jnp.ndarray] = gaussian_b = Gaussian.from_samples(y, weights=ot_problem.b) # Brenier potential for ground cost ||x-y||^2/2, so multiple by two for cost ||x-y||^2 f_potential = 2*gaussian_a.f_potential(dest=gaussian_b, points=x) + f_potential = f_potential - jnp.mean(f_potential) f_potential = f_potential if lse_mode else jnp.exp(f_potential) return f_potential @@ -211,9 +212,10 @@ def init_dual_a(self, ot_problem: LinearProblem, init_f: jnp.ndarray = None, lse f_potential = jnp.zeros(n) if init_f is None else init_f f_potential = self.init_sorting_dual(modified_cost, f_potential) + f_potential = f_potential - jnp.mean(f_potential) f_potential = f_potential if lse_mode else jnp.exp(f_potential) - + return f_potential From c90cd40036a50769b2feecb108330af8751635c2 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Sun, 3 Jul 2022 20:26:13 +0200 Subject: [PATCH 15/46] fix lse for null weights --- ott/core/initializers.py | 130 ++++++++++++++++------- ott/core/sinkhorn.py | 5 +- tests/core/continuous_barycenter_test.py | 2 +- tests/core/initializers_test.py | 51 ++++++++- tests/core/sinkhorn_test.py | 29 ++++- 5 files changed, 166 insertions(+), 51 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 5e7597ced..8cf8ae313 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -13,10 +13,11 @@ # limitations under the License. """Sinkhorn initializers.""" +from ctypes import Union import functools import jax from jax import numpy as jnp -from typing import Optional +from typing import Optional, Tuple from ott.core.linear_problems import LinearProblem from ott.geometry.pointcloud import PointCloud @@ -27,31 +28,54 @@ class SinkhornInitializer(): def init_dual_a(self, ot_problem: LinearProblem, lse_mode: bool = True) -> jnp.ndarray: """ - Input: - ot_problem: OT problem between discrete distributions of size n and m - - Return: - dual potential, array of size m + + Initialzation for Sinkhorn potential f + + Args: + ot_problem (LinearProblem): OT problem between discrete distributions of size n and m + lse_mode (bool, optional): Return log potential. Defaults to True. + + Returns: + jnp.ndarray: dual potential, array of size n """ - a = ot_problem.a - init_dual_a = jnp.zeros_like(a) if lse_mode else jnp.ones_like(a) - return init_dual_a + + return self.default_dual_a(ot_problem=ot_problem, lse_mode=lse_mode) def init_dual_b(self, ot_problem: LinearProblem, lse_mode: bool = True) -> jnp.ndarray: """ - Input: - ot_problem: OT problem between discrete distributions of size n and m - - Return: - dual potential, array of size m + + Initialzation for Sinkhorn potential g + + Args: + ot_problem (LinearProblem): OT problem between discrete distributions of size n and m + lse_mode (bool, optional): Return log potential. Defaults to True. 
+ + Returns: + jnp.ndarray: dual potential, array of size m """ - b = ot_problem.b - init_dual_b = jnp.zeros_like(b) if lse_mode else jnp.ones_like(b) - return init_dual_b + + return self.default_dual_b(ot_problem=ot_problem, lse_mode=lse_mode) - def remove_null_weight_potentials(self, ot_problem, init_dual_a, init_dual_b, lse_mode: bool=True): - # Cancel dual variables for zero weights. + def remove_null_weight_potentials(self, + ot_problem: LinearProblem, + init_dual_a: jnp.ndarray, + init_dual_b: jnp.ndarray, + lse_mode: bool=True) -> Tuple[jnp.ndarray]: + + """ + Cancel dual variables for zero weights. + + Args: + ot_problem (LinearProblem): + init_dual_a (jnp.ndarray): potential f, array of size n + init_dual_b (jnp.ndarray): potential g, array of size m + lse_mode (bool, optional): Return log potentials if true. Defaults to True. + + Returns: + Union[jnp.ndarray]: potentials (f,g) + """ + a, b = ot_problem.a, ot_problem.b init_dual_a = jnp.where( a > 0, init_dual_a, -jnp.inf if lse_mode else 0.0 @@ -61,12 +85,34 @@ def remove_null_weight_potentials(self, ot_problem, init_dual_a, init_dual_b, ls ) return init_dual_a, init_dual_b - def default_dual_a(self, ot_problem, lse_mode): + def default_dual_a(self, ot_problem: LinearProblem, lse_mode: bool = True) -> jnp.ndarray: + """ + + Return array of size n, with entries 0 is lse_mode is true, otherwise entries of 1s + + Args: + ot_problem (LinearProblem): + lse_mode (bool, optional): Return log potentials if true. Defaults to True. + + Returns: + jnp.ndarray: potential f, array of size n + """ a = ot_problem.a init_dual_a = jnp.zeros_like(a) if lse_mode else jnp.ones_like(a) return init_dual_a - def default_dual_b(self, ot_problem, lse_mode): + def default_dual_b(self, ot_problem: LinearProblem, lse_mode : bool=True) -> jnp.ndarray: + """ + + Return array of size m, with entries 0 is lse_mode is true, otherwise entries of 1s + + Args: + ot_problem (LinearProblem): + lse_mode (bool, optional): Return log potentials if true. Defaults to True. + + Returns: + jnp.ndarray: potential fg array of size m + """ b = ot_problem.b init_dual_b = jnp.zeros_like(b) if lse_mode else jnp.ones_like(b) return init_dual_b @@ -75,7 +121,7 @@ def default_dual_b(self, ot_problem, lse_mode): class GaussianInitializer(SinkhornInitializer): def __init__(self, stop_gradient: Optional[bool] =True) -> None: - """_summary_ + """ Args: stop_gradient (bool, optional): _description_. Defaults to True. 
@@ -87,10 +133,10 @@ def __init__(self, stop_gradient: Optional[bool] =True) -> None: def init_dual_a(self, ot_problem: LinearProblem, init_f: Optional[jnp.ndarray] =None, lse_mode: bool = True) -> jnp.ndarray: - """_summary_ + """ Returns: - _type_: _description_ + jnp.ndarray: potential f, array of size n """ # import here due to circular imports from ott.tools.gaussian_mixture.gaussian import Gaussian @@ -110,7 +156,7 @@ def init_dual_a(self, ot_problem: LinearProblem, init_f: Optional[jnp.ndarray] = # Brenier potential for ground cost ||x-y||^2/2, so multiple by two for cost ||x-y||^2 f_potential = 2*gaussian_a.f_potential(dest=gaussian_b, points=x) f_potential = f_potential - jnp.mean(f_potential) - f_potential = f_potential if lse_mode else jnp.exp(f_potential) + f_potential = f_potential if lse_mode else ot_problem.scaling_from_potential(f_potential) return f_potential class SortingInit(SinkhornInitializer): @@ -135,12 +181,12 @@ def __init__(self, self.max_iter = max_iter self.update_fn = self.vectorized_update if vector_min else self.coordinate_update - def vectorized_update(self, f: jnp.ndarray, modified_cost: jnp.ndarray): - """_summary_ + def vectorized_update(self, f: jnp.ndarray, modified_cost: jnp.ndarray) -> jnp.ndarray: + """ Args: - f (jnp.ndarray): _description_ - modified_cost (jnp.ndarray): _description_ + f (jnp.ndarray): potential f, array of size n + modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column Returns: _type_: _description_ @@ -148,12 +194,12 @@ def vectorized_update(self, f: jnp.ndarray, modified_cost: jnp.ndarray): f = jnp.min(modified_cost + f[None, :], axis=1) return f - def coordinate_update(self, f: jnp.ndarray, modified_cost: jnp.ndarray): + def coordinate_update(self, f: jnp.ndarray, modified_cost: jnp.ndarray) -> jnp.ndarray: """_summary_ Args: - f (jnp.ndarray): _description_ - modified_cost (jnp.ndarray): _description_ + f (jnp.ndarray): potential f, array of size n + modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column """ def body_fn(i, f): @@ -163,15 +209,17 @@ def body_fn(i, f): return jax.lax.fori_loop(0, len(f), body_fn, f) - def init_sorting_dual(self, modified_cost: jnp.ndarray, f_potential: jnp.ndarray): - """_summary_ + def init_sorting_dual(self, modified_cost: jnp.ndarray, f_potential: jnp.ndarray) -> jnp.ndarray: + """ + + Run DualSort algorithm Args: - modified_cost (jnp.ndarray): _description_ - f_potential (jnp.ndarray): _description_ + modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column + f_potential (jnp.ndarray): potential f, array of size n Returns: - _type_: _description_ + jnp.ndarray: potential f, array of size n """ it = 0 diff = self.tolerance + 1.0 @@ -196,11 +244,11 @@ def init_dual_a(self, ot_problem: LinearProblem, init_f: jnp.ndarray = None, lse """ Args: - ot_problem (LinearProblem): _description_ - init_f (jnp.ndarray, optional): _description_. Defaults to None. + ot_problem (LinearProblem): OT problem + init_f (jnp.ndarray, optional): potential f, array of size n. Defaults to None. 
Returns: - jnp.ndarray: _description_ + jnp.ndarray: potential f, array of size n """ cost_matrix = ot_problem.geom.cost_matrix if self.stop_gradient: @@ -214,7 +262,7 @@ def init_dual_a(self, ot_problem: LinearProblem, init_f: jnp.ndarray = None, lse f_potential = self.init_sorting_dual(modified_cost, f_potential) f_potential = f_potential - jnp.mean(f_potential) - f_potential = f_potential if lse_mode else jnp.exp(f_potential) + f_potential = f_potential if lse_mode else ot_problem.scaling_from_potential(f_potential) return f_potential diff --git a/ott/core/sinkhorn.py b/ott/core/sinkhorn.py index b0c78e2e9..32f6151fa 100644 --- a/ott/core/sinkhorn.py +++ b/ott/core/sinkhorn.py @@ -417,7 +417,8 @@ def __call__( # Cancel dual variables for zero weights. init_dual_a, init_dual_b = self.potential_initializer.remove_null_weight_potentials(ot_problem=ot_prob, init_dual_a=init_dual_a, - init_dual_b=init_dual_b) + init_dual_b=init_dual_b, + lse_mode=self.lse_mode) run_fn = jax.jit(run) if self.jit else run return run_fn(ot_prob, self, (init_dual_a, init_dual_b)) @@ -591,7 +592,7 @@ def run( init: Tuple[jnp.ndarray, ...] ) -> SinkhornOutput: """Run loop of the solver, outputting a state upgraded to an output.""" - iter_fun = _iterations_implicit if solver.implicit_diff else iterations + iter_fun = _iterations_implicit if solver else iterations out = iter_fun(ot_prob, solver, init) # Be careful here, the geom and the cost are injected at the end, where it # does not interfere with the implicit differentiation. diff --git a/tests/core/continuous_barycenter_test.py b/tests/core/continuous_barycenter_test.py index 4290feac8..cb0dec4f9 100644 --- a/tests/core/continuous_barycenter_test.py +++ b/tests/core/continuous_barycenter_test.py @@ -13,7 +13,7 @@ # Lint as: python3 """Tests for Continuous barycenters.""" - +import sys import jax import jax.numpy as jnp from absl.testing import absltest, parameterized diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py index 3bb7a5a99..90321ef90 100644 --- a/tests/core/initializers_test.py +++ b/tests/core/initializers_test.py @@ -21,11 +21,10 @@ from ott.core.sinkhorn import sinkhorn +from ott.geometry.geometry import Geometry from ott.geometry.pointcloud import PointCloud from ott.core import initializers as init_lib - - - +from ott.core.linear_problems import LinearProblem class InitializerTest(parameterized.TestCase): @@ -90,6 +89,52 @@ def run_sinkhorn(x, y, a=None, b=None, init_dual_a=None): # check initializer is better self.assertTrue(base_num_iter >= sort_num_iter) + def test_default_initializer(self): + """Tests default initializer""" + + # definte ot problem + np.random.seed(0) + n, d = 1000, 2 + mu_a = np.array([-1,1])*5 + mu_b = np.array([0,0]) + + + x = np.random.normal(size=n*d).reshape(n,d) + mu_a + y = np.random.normal(size=n*d).reshape(n,d) + mu_b + + + n = len(x) + m = len(y) + a = np.ones(n)/n + b = np.ones(m)/m + + x_jnp, y_jnp = jnp.array(x), jnp.array(y) + + gaus_init = init_lib.GaussianInitializer() + + geom_kwargs = {'epsilon': 0.01} + geom = PointCloud(x_jnp, y_jnp, **geom_kwargs) + + ot_problem = LinearProblem(geom=geom, a=a, b=b) + default_potential_a = gaus_init.default_dual_a(ot_problem=ot_problem) + default_potential_b = gaus_init.default_dual_b(ot_problem=ot_problem) + + # check default is 0 + self.assertTrue(( jnp.zeros(n) == default_potential_a).all()) + self.assertTrue(( jnp.zeros(m) == default_potential_b).all()) + + # check gausian init returns 0 for non point cloud geometry + new_geom = 
Geometry(cost_matrix=geom.cost_matrix, **geom_kwargs)
+    ot_problem = LinearProblem(geom=new_geom, a=a, b=b)
+    init_potential_a = gaus_init.init_dual_a(ot_problem=ot_problem)
+    init_potential_b = gaus_init.init_dual_b(ot_problem=ot_problem)
+
+    self.assertTrue(( jnp.zeros(n) == init_potential_a).all())
+    self.assertTrue(( jnp.zeros(m) == init_potential_b).all())
+
   def test_gaus_initializer(self):
     """Tests Gaussian initializer."""

From a680d6b3541db5829bf16635dc590f79127f939d Mon Sep 17 00:00:00 2001
From: James Thornton
Date: Sun, 3 Jul 2022 16:26:29 -0700
Subject: [PATCH 16/46] fix flake8 and accidental removal

---
 ott/core/__init__.py                     |   4 +-
 ott/core/initializers.py                 | 495 ++++++++++++-----------
 ott/core/problems.py                     |   1 +
 ott/core/sinkhorn.py                     |  29 +-
 ott/tools/gaussian_mixture/gaussian.py   |  26 +-
 ott/tools/gaussian_mixture/scale_tril.py |   2 +-
 ott/tools/transport.py                   |   2 +-
 tests/core/continuous_barycenter_test.py |   2 +-
 tests/core/initializers_test.py          | 133 +++---
 tests/core/sinkhorn_test.py              |  23 +-
 10 files changed, 363 insertions(+), 354 deletions(-)

diff --git a/ott/core/__init__.py b/ott/core/__init__.py
index d671c3529..67ef6898e 100644
--- a/ott/core/__init__.py
+++ b/ott/core/__init__.py
@@ -22,19 +22,19 @@
     discrete_barycenter,
     gromov_wasserstein,
     implicit_differentiation,
+    initializers,
     linear_problems,
     momentum,
     quad_problems,
     sinkhorn,
     sinkhorn_lr,
-    initializers
 )

 # from . 
import neuraldual

 from .implicit_differentiation import ImplicitDiff
+from .initializers import SinkhornInitializer
 from .linear_problems import LinearProblem
 from .sinkhorn import Sinkhorn
 from .sinkhorn_lr import LRSinkhorn
-from .initializers import SinkhornInitializer

 # pytype: enable=import-error  # kwargs-checking
diff --git a/ott/core/initializers.py b/ott/core/initializers.py
index 8cf8ae313..66a5107e3 100644
--- a/ott/core/initializers.py
+++ b/ott/core/initializers.py
@@ -11,268 +11,275 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """Sinkhorn initializers."""
-from ctypes import Union
-import functools
+from typing import Optional, Tuple
+
 import jax
 from jax import numpy as jnp
-from typing import Optional, Tuple

 from ott.core.linear_problems import LinearProblem
 from ott.geometry.pointcloud import PointCloud

-
 class SinkhornInitializer():

+  def init_dual_a(
+      self, ot_problem: LinearProblem, lse_mode: bool = True
+  ) -> jnp.ndarray:
+    """
+    Initialization for Sinkhorn potential f
+
+    Args:
+      ot_problem (LinearProblem): OT problem between discrete distributions of size n and m
+      lse_mode (bool, optional): Return log potential. Defaults to True. 
+ + Returns: + jnp.ndarray: dual potential, array of size n + """ + + return self.default_dual_a(ot_problem=ot_problem, lse_mode=lse_mode) + + def init_dual_b( + self, ot_problem: LinearProblem, lse_mode: bool = True + ) -> jnp.ndarray: + """ + Initialzation for Sinkhorn potential g + + Args: + ot_problem (LinearProblem): OT problem between discrete distributions of size n and m + lse_mode (bool, optional): Return log potential. Defaults to True. + + Returns: + jnp.ndarray: dual potential, array of size m + """ + + return self.default_dual_b(ot_problem=ot_problem, lse_mode=lse_mode) + + def remove_null_weight_potentials( + self, + ot_problem: LinearProblem, + init_dual_a: jnp.ndarray, + init_dual_b: jnp.ndarray, + lse_mode: bool = True + ) -> Tuple[jnp.ndarray]: + """ + Cancel dual variables for zero weights. + + Args: + ot_problem (LinearProblem): + init_dual_a (jnp.ndarray): potential f, array of size n + init_dual_b (jnp.ndarray): potential g, array of size m + lse_mode (bool, optional): Return log potentials if true. Defaults to True. + + Returns: + Union[jnp.ndarray]: potentials (f,g) + """ + + a, b = ot_problem.a, ot_problem.b + init_dual_a = jnp.where(a > 0, init_dual_a, -jnp.inf if lse_mode else 0.0) + init_dual_b = jnp.where(b > 0, init_dual_b, -jnp.inf if lse_mode else 0.0) + return init_dual_a, init_dual_b + + def default_dual_a( + self, ot_problem: LinearProblem, lse_mode: bool = True + ) -> jnp.ndarray: + """ + Return array of size n, with entries 0 is lse_mode is true, otherwise entries of 1s + + Args: + ot_problem (LinearProblem): + lse_mode (bool, optional): Return log potentials if true. Defaults to True. + + Returns: + jnp.ndarray: potential f, array of size n + """ + a = ot_problem.a + init_dual_a = jnp.zeros_like(a) if lse_mode else jnp.ones_like(a) + return init_dual_a - Return array of size m, with entries 0 is lse_mode is true, otherwise entries of 1s + def default_dual_b( + self, ot_problem: LinearProblem, lse_mode: bool = True + ) -> jnp.ndarray: + """ + Return array of size m, with entries 0 is lse_mode is true, otherwise entries of 1s - Args: - ot_problem (LinearProblem): - lse_mode (bool, optional): Return log potentials if true. Defaults to True. + Args: + ot_problem (LinearProblem): + lse_mode (bool, optional): Return log potentials if true. Defaults to True. - Returns: - jnp.ndarray: potential fg array of size m - """ - b = ot_problem.b - init_dual_b = jnp.zeros_like(b) if lse_mode else jnp.ones_like(b) - return init_dual_b + Returns: + jnp.ndarray: potential fg array of size m + """ + b = ot_problem.b + init_dual_b = jnp.zeros_like(b) if lse_mode else jnp.ones_like(b) + return init_dual_b class GaussianInitializer(SinkhornInitializer): - def __init__(self, stop_gradient: Optional[bool] =True) -> None: - """ - - Args: - stop_gradient (bool, optional): _description_. Defaults to True. - """ - super().__init__() - - self.stop_gradient = stop_gradient + def __init__(self, stop_gradient: Optional[bool] = True) -> None: + """ + Args: + stop_gradient (bool, optional): _description_. Defaults to True. 
+ """ + super().__init__() + + self.stop_gradient = stop_gradient + + def init_dual_a( + self, + ot_problem: LinearProblem, + init_f: Optional[jnp.ndarray] = None, + lse_mode: bool = True + ) -> jnp.ndarray: + """ + Returns: + jnp.ndarray: potential f, array of size n + """ + # import here due to circular imports + from ott.tools.gaussian_mixture.gaussian import Gaussian + + cost_matrix = ot_problem.geom.cost_matrix + if self.stop_gradient: + cost_matrix = jax.lax.stop_gradient(cost_matrix) + + if not isinstance(ot_problem.geom, PointCloud): + # warning that init not applied + return self.default_dual_a(ot_problem, lse_mode) + else: + x = ot_problem.geom.x + y = ot_problem.geom.y + gaussian_a = Gaussian.from_samples(x, weights=ot_problem.a) + gaussian_b = Gaussian.from_samples(y, weights=ot_problem.b) + # Brenier potential for ground cost ||x-y||^2/2, so multiple by two for cost ||x-y||^2 + f_potential = 2 * gaussian_a.f_potential(dest=gaussian_b, points=x) + f_potential = f_potential - jnp.mean(f_potential) + f_potential = f_potential if lse_mode else ot_problem.scaling_from_potential( + f_potential + ) + return f_potential - - def init_dual_a(self, ot_problem: LinearProblem, init_f: Optional[jnp.ndarray] =None, lse_mode: bool = True) -> jnp.ndarray: - - """ - - Returns: - jnp.ndarray: potential f, array of size n - """ - # import here due to circular imports - from ott.tools.gaussian_mixture.gaussian import Gaussian - - cost_matrix = ot_problem.geom.cost_matrix - if self.stop_gradient: - cost_matrix = jax.lax.stop_gradient(cost_matrix) - - if not isinstance(ot_problem.geom, PointCloud): - # warning that init not applied - return self.default_dual_a(ot_problem, lse_mode) - else: - x = ot_problem.geom.x - y = ot_problem.geom.y - gaussian_a = Gaussian.from_samples(x, weights=ot_problem.a) - gaussian_b = Gaussian.from_samples(y, weights=ot_problem.b) - # Brenier potential for ground cost ||x-y||^2/2, so multiple by two for cost ||x-y||^2 - f_potential = 2*gaussian_a.f_potential(dest=gaussian_b, points=x) - f_potential = f_potential - jnp.mean(f_potential) - f_potential = f_potential if lse_mode else ot_problem.scaling_from_potential(f_potential) - return f_potential class SortingInit(SinkhornInitializer): - def __init__(self, - vector_min: Optional[bool] = False, - tol: Optional[float] = 1e-2, - max_iter: Optional[int] = 100, - stop_gradient: Optional[bool] = True) -> None: - """_summary_ - - Args: - vector_min (Optional[bool], optional): _description_. Defaults to False. - tol (Optional[float], optional): _description_. Defaults to 1e-2. - max_iter (Optional[int], optional): _description_. Defaults to 100. - stop_gradient (Optional[bool], optional): _description_. Defaults to True. 
- """ - super().__init__() - - self.tolerance = tol - self.stop_gradient = stop_gradient - self.max_iter = max_iter - self.update_fn = self.vectorized_update if vector_min else self.coordinate_update - - def vectorized_update(self, f: jnp.ndarray, modified_cost: jnp.ndarray) -> jnp.ndarray: - """ - - Args: - f (jnp.ndarray): potential f, array of size n - modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column - - Returns: - _type_: _description_ - """ - f = jnp.min(modified_cost + f[None, :], axis=1) - return f - - def coordinate_update(self, f: jnp.ndarray, modified_cost: jnp.ndarray) -> jnp.ndarray: - """_summary_ - - Args: - f (jnp.ndarray): potential f, array of size n - modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column - """ - - def body_fn(i, f): - new_f = jnp.min(modified_cost[i, :] + f) - f = f.at[i].set(new_f) - return f - - return jax.lax.fori_loop(0, len(f), body_fn, f) - - def init_sorting_dual(self, modified_cost: jnp.ndarray, f_potential: jnp.ndarray) -> jnp.ndarray: - """ - - Run DualSort algorithm - - Args: - modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column - f_potential (jnp.ndarray): potential f, array of size n - - Returns: - jnp.ndarray: potential f, array of size n - """ - it = 0 - diff = self.tolerance + 1.0 - - state = (f_potential, diff, it) - def body_fn(state): - prev_f, _, it = state - f_potential = self.update_fn(prev_f, modified_cost) - diff = jnp.sum((f_potential - prev_f) ** 2) - it += 1 - return f_potential, diff, it - - def cond_fn(state): - _, diff, it = state - return (diff > self.tolerance) & (it < self.max_iter) - - f_potential, _, it = jax.lax.while_loop(cond_fun=cond_fn, body_fun=body_fn, init_val=state) - - return f_potential - - def init_dual_a(self, ot_problem: LinearProblem, init_f: jnp.ndarray = None, lse_mode: bool = True) -> jnp.ndarray: - """ - - Args: - ot_problem (LinearProblem): OT problem - init_f (jnp.ndarray, optional): potential f, array of size n. Defaults to None. - - Returns: - jnp.ndarray: potential f, array of size n - """ - cost_matrix = ot_problem.geom.cost_matrix - if self.stop_gradient: - cost_matrix = jax.lax.stop_gradient(cost_matrix) - - modified_cost = cost_matrix - jnp.diag(cost_matrix)[None, :] - - n = cost_matrix.shape[0] - f_potential = jnp.zeros(n) if init_f is None else init_f - - f_potential = self.init_sorting_dual(modified_cost, f_potential) - f_potential = f_potential - jnp.mean(f_potential) - - f_potential = f_potential if lse_mode else ot_problem.scaling_from_potential(f_potential) - - return f_potential - - - - - - - - - - - + def __init__( + self, + vector_min: Optional[bool] = False, + tol: Optional[float] = 1e-2, + max_iter: Optional[int] = 100, + stop_gradient: Optional[bool] = True + ) -> None: + """ + Args: + vector_min (Optional[bool], optional): _description_. Defaults to False. + tol (Optional[float], optional): _description_. Defaults to 1e-2. + max_iter (Optional[int], optional): _description_. Defaults to 100. + stop_gradient (Optional[bool], optional): _description_. Defaults to True. 
+ """ + super().__init__() + + self.tolerance = tol + self.stop_gradient = stop_gradient + self.max_iter = max_iter + self.update_fn = self.vectorized_update if vector_min else self.coordinate_update + + def vectorized_update( + self, f: jnp.ndarray, modified_cost: jnp.ndarray + ) -> jnp.ndarray: + """ + Args: + f (jnp.ndarray): potential f, array of size n + modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column + + Returns: + jnp.ndarray: updated potential vector, f + """ + f = jnp.min(modified_cost + f[None, :], axis=1) + return f + + def coordinate_update( + self, f: jnp.ndarray, modified_cost: jnp.ndarray + ) -> jnp.ndarray: + """ + + Args: + f (jnp.ndarray): potential f, array of size n + modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column + + Returns: + jnp.ndarray: updated potential vector, f + """ + + def body_fn(i, f): + new_f = jnp.min(modified_cost[i, :] + f) + f = f.at[i].set(new_f) + return f + + return jax.lax.fori_loop(0, len(f), body_fn, f) + + def init_sorting_dual( + self, modified_cost: jnp.ndarray, f_potential: jnp.ndarray + ) -> jnp.ndarray: + """ + Run DualSort algorithm + + Args: + modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column + f_potential (jnp.ndarray): potential f, array of size n + + Returns: + jnp.ndarray: potential f, array of size n + """ + it = 0 + diff = self.tolerance + 1.0 + + state = (f_potential, diff, it) + + def body_fn(state): + prev_f, _, it = state + f_potential = self.update_fn(prev_f, modified_cost) + diff = jnp.sum((f_potential - prev_f) ** 2) + it += 1 + return f_potential, diff, it + + def cond_fn(state): + _, diff, it = state + return (diff > self.tolerance) & (it < self.max_iter) + + f_potential, _, it = jax.lax.while_loop( + cond_fun=cond_fn, body_fun=body_fn, init_val=state + ) + + return f_potential + + def init_dual_a( + self, + ot_problem: LinearProblem, + init_f: jnp.ndarray = None, + lse_mode: bool = True + ) -> jnp.ndarray: + """ + Args: + ot_problem (LinearProblem): OT problem + init_f (jnp.ndarray, optional): potential f, array of size n. Defaults to None. 
+ + Returns: + jnp.ndarray: potential f, array of size n + """ + cost_matrix = ot_problem.geom.cost_matrix + if self.stop_gradient: + cost_matrix = jax.lax.stop_gradient(cost_matrix) + + modified_cost = cost_matrix - jnp.diag(cost_matrix)[None, :] + + n = cost_matrix.shape[0] + f_potential = jnp.zeros(n) if init_f is None else init_f + + f_potential = self.init_sorting_dual(modified_cost, f_potential) + f_potential = f_potential - jnp.mean(f_potential) + + f_potential = f_potential if lse_mode else ot_problem.scaling_from_potential( + f_potential + ) + + return f_potential diff --git a/ott/core/problems.py b/ott/core/problems.py index ce9ee328b..e60b4deb1 100644 --- a/ott/core/problems.py +++ b/ott/core/problems.py @@ -20,6 +20,7 @@ from ott.core import linear_problems, quad_problems from ott.geometry import geometry, pointcloud + def make( *args: Union[jnp.ndarray, geometry.Geometry, linear_problems.LinearProblem, quad_problems.QuadraticProblem], diff --git a/ott/core/sinkhorn.py b/ott/core/sinkhorn.py index 32f6151fa..85c3665bb 100644 --- a/ott/core/sinkhorn.py +++ b/ott/core/sinkhorn.py @@ -14,7 +14,6 @@ # Lint as: python3 """A Jax implementation of the Sinkhorn algorithm.""" -from pickle import NONE from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple import jax @@ -24,10 +23,10 @@ from ott.core import anderson as anderson_lib from ott.core import fixed_point_loop from ott.core import implicit_differentiation as implicit_lib +from ott.core import initializers as init_lib from ott.core import linear_problems from ott.core import momentum as momentum_lib from ott.core import unbalanced_functions -from ott.core import initializers as init_lib from ott.geometry import geometry @@ -351,7 +350,8 @@ def __init__( use_danskin: Optional[bool] = None, implicit_diff: Optional[implicit_lib.ImplicitDiff ] = implicit_lib.ImplicitDiff(), # noqa: E124 - potential_initializer: Optional[init_lib.SinkhornInitializer] = init_lib.SinkhornInitializer(), + potential_initializer: Optional[init_lib.SinkhornInitializer + ] = init_lib.SinkhornInitializer(), jit: bool = True ): self.lse_mode = lse_mode @@ -409,16 +409,22 @@ def __call__( init_dual_a, init_dual_b = (init if init is not None else (None, None)) if init_dual_a is None: - init_dual_a = self.potential_initializer.init_dual_a(ot_problem=ot_prob, lse_mode=self.lse_mode) + init_dual_a = self.potential_initializer.init_dual_a( + ot_problem=ot_prob, lse_mode=self.lse_mode + ) if init_dual_b is None: - init_dual_b = self.potential_initializer.init_dual_b(ot_problem=ot_prob, lse_mode=self.lse_mode) + init_dual_b = self.potential_initializer.init_dual_b( + ot_problem=ot_prob, lse_mode=self.lse_mode + ) # Cancel dual variables for zero weights. - init_dual_a, init_dual_b = self.potential_initializer.remove_null_weight_potentials(ot_problem=ot_prob, - init_dual_a=init_dual_a, - init_dual_b=init_dual_b, - lse_mode=self.lse_mode) + init_dual_a, init_dual_b = self.potential_initializer.remove_null_weight_potentials( + ot_problem=ot_prob, + init_dual_a=init_dual_a, + init_dual_b=init_dual_b, + lse_mode=self.lse_mode + ) run_fn = jax.jit(run) if self.jit else run return run_fn(ot_prob, self, (init_dual_a, init_dual_b)) @@ -592,7 +598,7 @@ def run( init: Tuple[jnp.ndarray, ...] 
) -> SinkhornOutput: """Run loop of the solver, outputting a state upgraded to an output.""" - iter_fun = _iterations_implicit if solver else iterations + iter_fun = _iterations_implicit if solver.implicit_diff else iterations out = iter_fun(ot_prob, solver, init) # Be careful here, the geom and the cost are injected at the end, where it # does not interfere with the implicit differentiation. @@ -696,7 +702,8 @@ def make( precondition_fun: Optional[Callable[[float], float]] = None, parallel_dual_updates: bool = False, use_danskin: bool = None, - potential_initializer: Optional[init_lib.SinkhornInitializer] = init_lib.SinkhornInitializer(), + potential_initializer: Optional[init_lib.SinkhornInitializer + ] = init_lib.SinkhornInitializer(), jit: bool = False ) -> Sinkhorn: """For backward compatibility.""" diff --git a/ott/tools/gaussian_mixture/gaussian.py b/ott/tools/gaussian_mixture/gaussian.py index ed69bfdb3..267951068 100644 --- a/ott/tools/gaussian_mixture/gaussian.py +++ b/ott/tools/gaussian_mixture/gaussian.py @@ -23,9 +23,11 @@ LOG2PI = math.log(2. * math.pi) + @jax.vmap def batch_inner_product(x, y): - return x.dot(y) + return x.dot(y) + @jax.tree_util.register_pytree_node_class class Gaussian: @@ -34,9 +36,11 @@ class Gaussian: def __init__(self, loc: jnp.ndarray, scale: scale_tril.ScaleTriL): self._loc = loc self._scale = scale - + @classmethod - def from_samples(cls, x:jnp.ndarray, weights: jnp.ndarray = None) -> 'Gaussian': + def from_samples( + cls, x: jnp.ndarray, weights: jnp.ndarray = None + ) -> 'Gaussian': """Construct a Gaussian from weighted samples Args: @@ -49,7 +53,7 @@ def from_samples(cls, x:jnp.ndarray, weights: jnp.ndarray = None) -> 'Gaussian': if weights is None: n = x.shape[0] - weights = jnp.ones(n)/ n + weights = jnp.ones(n) / n mean = weights.dot(x) scaled_centered_x = (x - mean) * weights.reshape(-1, 1) @@ -64,7 +68,6 @@ def from_random( stdev: float = 0.1, dtype: Optional[jnp.dtype] = None ) -> 'Gaussian': - """Construct a random Gaussian. 
Args: @@ -163,15 +166,16 @@ def f_potential(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: jnp.ndarray: _description_ """ scale_matrix = self.scale.transport_scale_matrix(dest_scale=dest.scale) - centered_x = points - self.loc - scaled_x = jnp.transpose(jnp.matmul(scale_matrix, jnp.transpose(centered_x))) + centered_x = points - self.loc + scaled_x = jnp.transpose( + jnp.matmul(scale_matrix, jnp.transpose(centered_x)) + ) return ( - 0.5 * batch_inner_product(points, points) - - 0.5 * batch_inner_product(centered_x, scaled_x) - - (points).dot(dest.loc) + 0.5 * batch_inner_product(points, points) - + 0.5 * batch_inner_product(centered_x, scaled_x) - + (points).dot(dest.loc) ) - def transport(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: """_summary_ diff --git a/ott/tools/gaussian_mixture/scale_tril.py b/ott/tools/gaussian_mixture/scale_tril.py index 35cfa6af2..e1b49607f 100644 --- a/ott/tools/gaussian_mixture/scale_tril.py +++ b/ott/tools/gaussian_mixture/scale_tril.py @@ -160,7 +160,7 @@ def _flatten_cov(cov: jnp.ndarray) -> jnp.ndarray: def transport_scale_matrix(self, dest_scale: 'ScaleTriL') -> jnp.ndarray: """ Scaling matrix used in transport between 0-mean normal, mu, w/ current scale to one w/ dest_scale, nu - + m = Sigma_mu ^{-1/2} [ Sigma_mu ^{1/2} Sigma_nu Sigma_mu ^{1/2}] ^{1/2}Sigma_mu ^{-1/2} Args: diff --git a/ott/tools/transport.py b/ott/tools/transport.py index bc694202f..983ea8a75 100644 --- a/ott/tools/transport.py +++ b/ott/tools/transport.py @@ -121,7 +121,7 @@ def solve( linear = isinstance(pb, linear_problems.LinearProblem) solver_fn = sinkhorn.make if linear else gromov_wasserstein.make geom_keys = ['cost_fn', 'power', 'online'] - + init_dual_a = kwargs.get('init_dual_a', None) init_dual_b = kwargs.get('init_dual_b', None) init_keys = ['init_dual_a', 'init_dual_b'] diff --git a/tests/core/continuous_barycenter_test.py b/tests/core/continuous_barycenter_test.py index cb0dec4f9..4290feac8 100644 --- a/tests/core/continuous_barycenter_test.py +++ b/tests/core/continuous_barycenter_test.py @@ -13,7 +13,7 @@ # Lint as: python3 """Tests for Continuous barycenters.""" -import sys + import jax import jax.numpy as jnp from absl.testing import absltest, parameterized diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py index 90321ef90..47bbcfa44 100644 --- a/tests/core/initializers_test.py +++ b/tests/core/initializers_test.py @@ -19,12 +19,11 @@ import numpy as np from absl.testing import absltest, parameterized - +from ott.core import initializers as init_lib +from ott.core.linear_problems import LinearProblem from ott.core.sinkhorn import sinkhorn from ott.geometry.geometry import Geometry from ott.geometry.pointcloud import PointCloud -from ott.core import initializers as init_lib -from ott.core.linear_problems import LinearProblem class InitializerTest(parameterized.TestCase): @@ -42,22 +41,24 @@ def test_sorting_init(self): # define sinkhorn functions @jax.jit def run_sinkhorn_sort_init(x, y, a=None, b=None, init_dual_a=None): - sink_kwargs = {'jit': True, - 'threshold': 0.001, - 'max_iterations': 10**5, - 'potential_initializer': sort_init} - geom_kwargs = {'epsilon': 0.01} - geom = PointCloud(x, y, **geom_kwargs) - out = sinkhorn(geom, a=a, b=b, init_dual_a=init_dual_a, **sink_kwargs) - return out + sink_kwargs = { + 'jit': True, + 'threshold': 0.001, + 'max_iterations': 10 ** 5, + 'potential_initializer': sort_init + } + geom_kwargs = {'epsilon': 0.01} + geom = PointCloud(x, y, **geom_kwargs) + out = sinkhorn(geom, 
a=a, b=b, init_dual_a=init_dual_a, **sink_kwargs)
+    return out

   @jax.jit
   def run_sinkhorn(x, y, a=None, b=None, init_dual_a=None):
-    sink_kwargs = {'jit': True, 'threshold': 0.001, 'max_iterations': 10**5}
-    geom_kwargs = {'epsilon': 0.01}
-    geom = PointCloud(x, y, **geom_kwargs)
-    out = sinkhorn(geom, a=a, b=b, init_dual_a=init_dual_a, **sink_kwargs)
-    return out
+    sink_kwargs = {'jit': True, 'threshold': 0.001, 'max_iterations': 10 ** 5}
+    geom_kwargs = {'epsilon': 0.01}
+    geom = PointCloud(x, y, **geom_kwargs)
+    out = sinkhorn(geom, a=a, b=b, init_dual_a=init_dual_a, **sink_kwargs)
+    return out

   # define ot problem
   x_init = np.array([-1., 0, .22])
   y_init = np.array([0., 0, 1.1])
   buf = 500

   np.random.seed(0)
-    x = np.concatenate([x_init, 10 + np.abs(np.random.normal(size=buf))])*5
-    y = np.concatenate([y_init, 10 + np.abs(np.random.normal(size=buf))])*5
+    x = np.concatenate([x_init, 10 + np.abs(np.random.normal(size=buf))]) * 5
+    y = np.concatenate([y_init, 10 + np.abs(np.random.normal(size=buf))]) * 5

   x = np.sort(x)
   y = np.sort(y)

   n = len(x)
   m = len(y)
-    a = np.ones(n)/n
-    b = np.ones(m)/m
+    a = np.ones(n) / n
+    b = np.ones(m) / m

-    x_jnp, y_jnp = jnp.array(x.reshape(-1,1)), jnp.array(y.reshape(-1,1))
+    x_jnp, y_jnp = jnp.array(x.reshape(-1, 1)), jnp.array(y.reshape(-1, 1))

   # run sinkhorn
-    sink_out = run_sinkhorn(x=x_jnp,y=y_jnp,a=a, b=b)
+    sink_out = run_sinkhorn(x=x_jnp, y=y_jnp, a=a, b=b)
   base_num_iter = jnp.sum(sink_out.errors > -1)
-
-    sink_out = run_sinkhorn_sort_init(x=x_jnp,y=y_jnp, a=a, b=b)
+    sink_out = run_sinkhorn_sort_init(x=x_jnp, y=y_jnp, a=a, b=b)
   sort_num_iter = jnp.sum(sink_out.errors > -1)

   # check initializer is better
@@ -95,18 +95,16 @@ def test_default_initializer(self):
     # define ot problem
     np.random.seed(0)
     n, d = 1000, 2
-    mu_a = np.array([-1,1])*5
-    mu_b = np.array([0,0])
-
-
-    x = np.random.normal(size=n*d).reshape(n,d) + mu_a
-    y = np.random.normal(size=n*d).reshape(n,d) + mu_b
+    mu_a = np.array([-1, 1]) * 5
+    mu_b = np.array([0, 0])
+    x = np.random.normal(size=n * d).reshape(n, d) + mu_a
+    y = np.random.normal(size=n * d).reshape(n, d) + mu_b

     n = len(x)
     m = len(y)
-    a = np.ones(n)/n
-    b = np.ones(m)/m
+    a = np.ones(n) / n
+    b = np.ones(m) / m

     x_jnp, y_jnp = jnp.array(x), jnp.array(y)

@@ -120,74 +118,67 @@ def test_default_initializer(self):
     default_potential_b = gaus_init.default_dual_b(ot_problem=ot_problem)

     # check default is 0
-    self.assertTrue(( jnp.zeros(n) == default_potential_a).all())
-    self.assertTrue(( jnp.zeros(m) == default_potential_b).all())
-
+    self.assertTrue((jnp.zeros(n) == default_potential_a).all())
+    self.assertTrue((jnp.zeros(m) == default_potential_b).all())
+
     # check Gaussian init returns 0 for non point cloud geometry
-    new_geom = Geometry(cost_matrix=geom.cost_matrix, **geom_kwargs)
+    new_geom = Geometry(cost_matrix=geom.cost_matrix, **geom_kwargs)
     ot_problem = LinearProblem(geom=new_geom, a=a, b=b)
     init_potential_a = gaus_init.init_dual_a(ot_problem=ot_problem)
     init_potential_b = gaus_init.init_dual_b(ot_problem=ot_problem)

-    self.assertTrue(( jnp.zeros(n) == init_potential_a).all())
-    self.assertTrue(( jnp.zeros(m) == init_potential_b).all())
-
-
-
-
+    self.assertTrue((jnp.zeros(n) == init_potential_a).all())
+    self.assertTrue((jnp.zeros(m) == init_potential_b).all())

   def test_gaus_initializer(self):
     """Tests Gaussian initializer"""
-
+
     # init initializer
     gaus_init = init_lib.GaussianInitializer()
-
   @jax.jit
   def run_sinkhorn(x, y, a=None, b=None, init_dual_a=None):
-    sink_kwargs =
{'jit': True, 'threshold': 0.001, 'max_iterations': 10 ** 5}
-    geom_kwargs = {'epsilon': 0.01}
-    geom = PointCloud(x, y, **geom_kwargs)
-    out = sinkhorn(geom, a=a, b=b, init_dual_a=init_dual_a, **sink_kwargs)
-    return out
+    sink_kwargs = {'jit': True, 'threshold': 0.001, 'max_iterations': 10 ** 5}
+    geom_kwargs = {'epsilon': 0.01}
+    geom = PointCloud(x, y, **geom_kwargs)
+    out = sinkhorn(geom, a=a, b=b, init_dual_a=init_dual_a, **sink_kwargs)
+    return out

   @jax.jit
   def run_sinkhorn_gaus_init(x, y, a=None, b=None, init_dual_a=None):
-    sink_kwargs = {'jit': True,
-                   'threshold': 0.001,
-                   'max_iterations': 10**5,
-                   'potential_initializer': gaus_init}
-
-    geom_kwargs = {'epsilon': 0.01}
-    geom = PointCloud(x, y, **geom_kwargs)
-    out = sinkhorn(geom, a=a, b=b, init_dual_a=init_dual_a, **sink_kwargs)
-    return out
-
+    sink_kwargs = {
+        'jit': True,
+        'threshold': 0.001,
+        'max_iterations': 10 ** 5,
+        'potential_initializer': gaus_init
+    }
+
+    geom_kwargs = {'epsilon': 0.01}
+    geom = PointCloud(x, y, **geom_kwargs)
+    out = sinkhorn(geom, a=a, b=b, init_dual_a=init_dual_a, **sink_kwargs)
+    return out

   # define ot problem
   np.random.seed(0)
   n, d = 1000, 2
-    mu_a = np.array([-1,1])*5
-    mu_b = np.array([0,0])
-
-
-    x = np.random.normal(size=n*d).reshape(n,d) + mu_a
-    y = np.random.normal(size=n*d).reshape(n,d) + mu_b
+    mu_a = np.array([-1, 1]) * 5
+    mu_b = np.array([0, 0])
+    x = np.random.normal(size=n * d).reshape(n, d) + mu_a
+    y = np.random.normal(size=n * d).reshape(n, d) + mu_b

   n = len(x)
   m = len(y)
-    a = np.ones(n)/n
-    b = np.ones(m)/m
+    a = np.ones(n) / n
+    b = np.ones(m) / m

   x_jnp, y_jnp = jnp.array(x), jnp.array(y)

   # run sinkhorn
-    sink_out = run_sinkhorn(x=x_jnp,y=y_jnp,a=a, b=b)
+    sink_out = run_sinkhorn(x=x_jnp, y=y_jnp, a=a, b=b)
   base_num_iter = jnp.sum(sink_out.errors > -1)

-    sink_out = run_sinkhorn_gaus_init(x=x_jnp,y=y_jnp, a=a, b=b)
+    sink_out = run_sinkhorn_gaus_init(x=x_jnp, y=y_jnp, a=a, b=b)
   gaus_num_iter = jnp.sum(sink_out.errors > -1)

   # check initializer is better

 if __name__ == '__main__':
-  absltest.main()
\ No newline at end of file
+  absltest.main()

diff --git a/tests/core/sinkhorn_test.py b/tests/core/sinkhorn_test.py
index c22e838ec..1ebac575c 100644
--- a/tests/core/sinkhorn_test.py
+++ b/tests/core/sinkhorn_test.py
@@ -442,15 +442,15 @@ def test_restart(self, lse_mode):
     )

     if lse_mode:
-      default_a = jnp.zeros_like(init_dual_a)
-      default_b = jnp.zeros_like(init_dual_b)
+      default_a = jnp.zeros_like(init_dual_a)
+      default_b = jnp.zeros_like(init_dual_b)
     else:
-      default_a = jnp.ones_like(init_dual_a)
-      default_b = jnp.ones_like(init_dual_b)
+      default_a = jnp.ones_like(init_dual_a)
+      default_b = jnp.ones_like(init_dual_b)
+
+    self.assertTrue((default_a != init_dual_a).all())
+    self.assertTrue((default_b != init_dual_b).all())

-    self.assertTrue(( default_a != init_dual_a).all())
-    self.assertTrue(( default_b != init_dual_b).all())
-
     out_restarted = sinkhorn.sinkhorn(
         geom,
         a=self.a,
@@ -465,8 +465,7 @@
     errors_restarted = out_restarted.errors
     err_restarted = errors_restarted[errors_restarted > -1][-1]
     self.assertGreater(threshold, err_restarted)
-
-
+
     num_iter_restarted = jnp.sum(errors_restarted > -1)

     # check we can only improve on error
     # num_iter = jnp.sum(errors>-1)
     # self.assertGreater(num_iter, num_iter_restarted)

     # check we can only improve on error
-    self.assertGreater(err+threshold, err_restarted)
+    self.assertGreater(err +
threshold, err_restarted) # # check first error in restart does at least as well as previous best - self.assertGreater(err+threshold, errors_restarted[2]) - self.assertGreater(err+threshold, errors_restarted[0]) + self.assertGreater(err + threshold, errors_restarted[2]) + self.assertGreater(err + threshold, errors_restarted[0]) # check only one iteration suffices when restarting with same data. self.assertEqual(num_iter_restarted, 1) From 9b0f2247b1e3e18d7e1bc3f92795fda29f2a335a Mon Sep 17 00:00:00 2001 From: James Thornton Date: Sun, 3 Jul 2022 16:36:01 -0700 Subject: [PATCH 17/46] tidy docstrings --- ott/core/initializers.py | 49 ++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 66a5107e3..3d9551c66 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -26,8 +26,7 @@ class SinkhornInitializer(): def init_dual_a( self, ot_problem: LinearProblem, lse_mode: bool = True ) -> jnp.ndarray: - """ - Initialzation for Sinkhorn potential f + """ Initialzation for Sinkhorn potential f. Args: ot_problem (LinearProblem): OT problem between discrete distributions of size n and m @@ -42,8 +41,7 @@ def init_dual_a( def init_dual_b( self, ot_problem: LinearProblem, lse_mode: bool = True ) -> jnp.ndarray: - """ - Initialzation for Sinkhorn potential g + """ Initialzation for Sinkhorn potential g. Args: ot_problem (LinearProblem): OT problem between discrete distributions of size n and m @@ -62,8 +60,7 @@ def remove_null_weight_potentials( init_dual_b: jnp.ndarray, lse_mode: bool = True ) -> Tuple[jnp.ndarray]: - """ - Cancel dual variables for zero weights. + """ Cancel dual variables for zero weights. Args: ot_problem (LinearProblem): @@ -83,8 +80,7 @@ def remove_null_weight_potentials( def default_dual_a( self, ot_problem: LinearProblem, lse_mode: bool = True ) -> jnp.ndarray: - """ - Return array of size n, with entries 0 is lse_mode is true, otherwise entries of 1s + """ Return array of size n, with entries 0 is lse_mode is true, otherwise entries of 1s. Args: ot_problem (LinearProblem): @@ -100,8 +96,7 @@ def default_dual_a( def default_dual_b( self, ot_problem: LinearProblem, lse_mode: bool = True ) -> jnp.ndarray: - """ - Return array of size m, with entries 0 is lse_mode is true, otherwise entries of 1s + """ Return array of size m, with entries 0 is lse_mode is true, otherwise entries of 1s. Args: ot_problem (LinearProblem): @@ -118,7 +113,8 @@ def default_dual_b( class GaussianInitializer(SinkhornInitializer): def __init__(self, stop_gradient: Optional[bool] = True) -> None: - """ + """ GaussianInitializer. + Args: stop_gradient (bool, optional): _description_. Defaults to True. """ @@ -132,9 +128,15 @@ def init_dual_a( init_f: Optional[jnp.ndarray] = None, lse_mode: bool = True ) -> jnp.ndarray: - """ - Returns: - jnp.ndarray: potential f, array of size n + """ Gaussian init function. + + Args: + ot_problem (LinearProblem): OT problem description with geometry and weights. + init_f (Optional[jnp.ndarray], optional): Pre dual sort initialization, when none sets entries as 0 + lse_mode (bool, optional): Return log potential if true. Defaults to True. + + Returns: + jnp.ndarray: jnp.ndarray: potential f, array of size n """ # import here due to circular imports from ott.tools.gaussian_mixture.gaussian import Gaussian @@ -169,12 +171,12 @@ def __init__( max_iter: Optional[int] = 100, stop_gradient: Optional[bool] = True ) -> None: - """ + """ Sorting Init class. 
Args: - vector_min (Optional[bool], optional): _description_. Defaults to False. - tol (Optional[float], optional): _description_. Defaults to 1e-2. - max_iter (Optional[int], optional): _description_. Defaults to 100. - stop_gradient (Optional[bool], optional): _description_. Defaults to True. + vector_min (Optional[bool], optional): Use vectorized inner loop if true. Defaults to False. + tol (Optional[float], optional): DualSort convergence threshold. Defaults to 1e-2. + max_iter (Optional[int], optional): Max DualSort steps. Defaults to 100. + stop_gradient (Optional[bool], optional): Do not trace gradient through the initializer. Defaults to True. """ super().__init__() @@ -186,7 +188,7 @@ def __init__( def vectorized_update( self, f: jnp.ndarray, modified_cost: jnp.ndarray ) -> jnp.ndarray: - """ + """ Inner loop DualSort Update. Args: f (jnp.ndarray): potential f, array of size n modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column @@ -200,7 +202,7 @@ def vectorized_update( def coordinate_update( self, f: jnp.ndarray, modified_cost: jnp.ndarray ) -> jnp.ndarray: - """ + """ Coordinate-wise updates within inner loop. Args: f (jnp.ndarray): potential f, array of size n @@ -220,8 +222,7 @@ def body_fn(i, f): def init_sorting_dual( self, modified_cost: jnp.ndarray, f_potential: jnp.ndarray ) -> jnp.ndarray: - """ - Run DualSort algorithm + """ Run DualSort algorithm. Args: modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column @@ -258,7 +259,7 @@ def init_dual_a( init_f: jnp.ndarray = None, lse_mode: bool = True ) -> jnp.ndarray: - """ + """ Apply DualSort algo. Args: ot_problem (LinearProblem): OT problem init_f (jnp.ndarray, optional): potential f, array of size n. Defaults to None. From d33d89c9b104c1d4d8f408e2ee57f331a57f7032 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Sun, 3 Jul 2022 16:38:19 -0700 Subject: [PATCH 18/46] tidy docstrings --- ott/core/initializers.py | 26 ++------------------------ 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 3d9551c66..aa04a18cc 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -27,30 +27,24 @@ def init_dual_a( self, ot_problem: LinearProblem, lse_mode: bool = True ) -> jnp.ndarray: """ Initialzation for Sinkhorn potential f. - Args: ot_problem (LinearProblem): OT problem between discrete distributions of size n and m lse_mode (bool, optional): Return log potential. Defaults to True. - Returns: jnp.ndarray: dual potential, array of size n """ - return self.default_dual_a(ot_problem=ot_problem, lse_mode=lse_mode) def init_dual_b( self, ot_problem: LinearProblem, lse_mode: bool = True ) -> jnp.ndarray: """ Initialzation for Sinkhorn potential g. - Args: ot_problem (LinearProblem): OT problem between discrete distributions of size n and m lse_mode (bool, optional): Return log potential. Defaults to True. - Returns: jnp.ndarray: dual potential, array of size m """ - return self.default_dual_b(ot_problem=ot_problem, lse_mode=lse_mode) def remove_null_weight_potentials( @@ -61,13 +55,11 @@ def remove_null_weight_potentials( lse_mode: bool = True ) -> Tuple[jnp.ndarray]: """ Cancel dual variables for zero weights. - Args: ot_problem (LinearProblem): init_dual_a (jnp.ndarray): potential f, array of size n init_dual_b (jnp.ndarray): potential g, array of size m lse_mode (bool, optional): Return log potentials if true. Defaults to True. 
- Returns: Union[jnp.ndarray]: potentials (f,g) """ @@ -81,11 +73,9 @@ def default_dual_a( self, ot_problem: LinearProblem, lse_mode: bool = True ) -> jnp.ndarray: """ Return array of size n, with entries 0 is lse_mode is true, otherwise entries of 1s. - Args: ot_problem (LinearProblem): lse_mode (bool, optional): Return log potentials if true. Defaults to True. - Returns: jnp.ndarray: potential f, array of size n """ @@ -97,11 +87,9 @@ def default_dual_b( self, ot_problem: LinearProblem, lse_mode: bool = True ) -> jnp.ndarray: """ Return array of size m, with entries 0 is lse_mode is true, otherwise entries of 1s. - Args: ot_problem (LinearProblem): lse_mode (bool, optional): Return log potentials if true. Defaults to True. - Returns: jnp.ndarray: potential fg array of size m """ @@ -114,7 +102,6 @@ class GaussianInitializer(SinkhornInitializer): def __init__(self, stop_gradient: Optional[bool] = True) -> None: """ GaussianInitializer. - Args: stop_gradient (bool, optional): _description_. Defaults to True. """ @@ -129,12 +116,10 @@ def init_dual_a( lse_mode: bool = True ) -> jnp.ndarray: """ Gaussian init function. - Args: ot_problem (LinearProblem): OT problem description with geometry and weights. init_f (Optional[jnp.ndarray], optional): Pre dual sort initialization, when none sets entries as 0 lse_mode (bool, optional): Return log potential if true. Defaults to True. - Returns: jnp.ndarray: jnp.ndarray: potential f, array of size n """ @@ -190,9 +175,8 @@ def vectorized_update( ) -> jnp.ndarray: """ Inner loop DualSort Update. Args: - f (jnp.ndarray): potential f, array of size n - modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column - + f (jnp.ndarray): potential f, array of size n + modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column Returns: jnp.ndarray: updated potential vector, f """ @@ -203,15 +187,12 @@ def coordinate_update( self, f: jnp.ndarray, modified_cost: jnp.ndarray ) -> jnp.ndarray: """ Coordinate-wise updates within inner loop. - Args: f (jnp.ndarray): potential f, array of size n modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column - Returns: jnp.ndarray: updated potential vector, f """ - def body_fn(i, f): new_f = jnp.min(modified_cost[i, :] + f) f = f.at[i].set(new_f) @@ -223,11 +204,9 @@ def init_sorting_dual( self, modified_cost: jnp.ndarray, f_potential: jnp.ndarray ) -> jnp.ndarray: """ Run DualSort algorithm. - Args: modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column f_potential (jnp.ndarray): potential f, array of size n - Returns: jnp.ndarray: potential f, array of size n """ @@ -263,7 +242,6 @@ def init_dual_a( Args: ot_problem (LinearProblem): OT problem init_f (jnp.ndarray, optional): potential f, array of size n. Defaults to None. 
- Returns: jnp.ndarray: potential f, array of size n """ From d83f9133e2571372051d74983e0a5c00cb2439e0 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Sun, 3 Jul 2022 16:49:14 -0700 Subject: [PATCH 19/46] docstring flake8 --- ott/core/initializers.py | 31 +++++++++++++++++------- ott/core/sinkhorn.py | 5 ++-- ott/tools/gaussian_mixture/scale_tril.py | 6 ++--- 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index aa04a18cc..87add48d8 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -27,8 +27,9 @@ def init_dual_a( self, ot_problem: LinearProblem, lse_mode: bool = True ) -> jnp.ndarray: """ Initialzation for Sinkhorn potential f. + Args: - ot_problem (LinearProblem): OT problem between discrete distributions of size n and m + ot_problem (LinearProblem): OT problem between discrete distributions of size n and m. lse_mode (bool, optional): Return log potential. Defaults to True. Returns: jnp.ndarray: dual potential, array of size n @@ -39,6 +40,7 @@ def init_dual_b( self, ot_problem: LinearProblem, lse_mode: bool = True ) -> jnp.ndarray: """ Initialzation for Sinkhorn potential g. + Args: ot_problem (LinearProblem): OT problem between discrete distributions of size n and m lse_mode (bool, optional): Return log potential. Defaults to True. @@ -55,6 +57,7 @@ def remove_null_weight_potentials( lse_mode: bool = True ) -> Tuple[jnp.ndarray]: """ Cancel dual variables for zero weights. + Args: ot_problem (LinearProblem): init_dual_a (jnp.ndarray): potential f, array of size n @@ -73,11 +76,12 @@ def default_dual_a( self, ot_problem: LinearProblem, lse_mode: bool = True ) -> jnp.ndarray: """ Return array of size n, with entries 0 is lse_mode is true, otherwise entries of 1s. + Args: ot_problem (LinearProblem): lse_mode (bool, optional): Return log potentials if true. Defaults to True. Returns: - jnp.ndarray: potential f, array of size n + jnp.ndarray: potential f, array of size n """ a = ot_problem.a init_dual_a = jnp.zeros_like(a) if lse_mode else jnp.ones_like(a) @@ -87,11 +91,12 @@ def default_dual_b( self, ot_problem: LinearProblem, lse_mode: bool = True ) -> jnp.ndarray: """ Return array of size m, with entries 0 is lse_mode is true, otherwise entries of 1s. + Args: ot_problem (LinearProblem): lse_mode (bool, optional): Return log potentials if true. Defaults to True. Returns: - jnp.ndarray: potential fg array of size m + jnp.ndarray: potential fg array of size m """ b = ot_problem.b init_dual_b = jnp.zeros_like(b) if lse_mode else jnp.ones_like(b) @@ -102,6 +107,7 @@ class GaussianInitializer(SinkhornInitializer): def __init__(self, stop_gradient: Optional[bool] = True) -> None: """ GaussianInitializer. + Args: stop_gradient (bool, optional): _description_. Defaults to True. """ @@ -116,12 +122,13 @@ def init_dual_a( lse_mode: bool = True ) -> jnp.ndarray: """ Gaussian init function. - Args: - ot_problem (LinearProblem): OT problem description with geometry and weights. - init_f (Optional[jnp.ndarray], optional): Pre dual sort initialization, when none sets entries as 0 - lse_mode (bool, optional): Return log potential if true. Defaults to True. - Returns: - jnp.ndarray: jnp.ndarray: potential f, array of size n + + Args: + ot_problem (LinearProblem): OT problem description with geometry and weights. + init_f (Optional[jnp.ndarray], optional): Pre dual sort initialization, when none sets entries as 0. + lse_mode (bool, optional): Return log potential if true. Defaults to True. 
+ Returns: + jnp.ndarray: jnp.ndarray: potential f, array of size n. """ # import here due to circular imports from ott.tools.gaussian_mixture.gaussian import Gaussian @@ -157,6 +164,7 @@ def __init__( stop_gradient: Optional[bool] = True ) -> None: """ Sorting Init class. + Args: vector_min (Optional[bool], optional): Use vectorized inner loop if true. Defaults to False. tol (Optional[float], optional): DualSort convergence threshold. Defaults to 1e-2. @@ -174,6 +182,7 @@ def vectorized_update( self, f: jnp.ndarray, modified_cost: jnp.ndarray ) -> jnp.ndarray: """ Inner loop DualSort Update. + Args: f (jnp.ndarray): potential f, array of size n modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column @@ -187,12 +196,14 @@ def coordinate_update( self, f: jnp.ndarray, modified_cost: jnp.ndarray ) -> jnp.ndarray: """ Coordinate-wise updates within inner loop. + Args: f (jnp.ndarray): potential f, array of size n modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column Returns: jnp.ndarray: updated potential vector, f """ + def body_fn(i, f): new_f = jnp.min(modified_cost[i, :] + f) f = f.at[i].set(new_f) @@ -204,6 +215,7 @@ def init_sorting_dual( self, modified_cost: jnp.ndarray, f_potential: jnp.ndarray ) -> jnp.ndarray: """ Run DualSort algorithm. + Args: modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column f_potential (jnp.ndarray): potential f, array of size n @@ -239,6 +251,7 @@ def init_dual_a( lse_mode: bool = True ) -> jnp.ndarray: """ Apply DualSort algo. + Args: ot_problem (LinearProblem): OT problem init_f (jnp.ndarray, optional): potential f, array of size n. Defaults to None. diff --git a/ott/core/sinkhorn.py b/ott/core/sinkhorn.py index 85c3665bb..c108e81d4 100644 --- a/ott/core/sinkhorn.py +++ b/ott/core/sinkhorn.py @@ -404,7 +404,6 @@ def __call__( init: Optional[Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]]] = None ) -> SinkhornOutput: """Main interface to run sinkhorn.""" # noqa: D401 - # initialization init_dual_a, init_dual_b = (init if init is not None else (None, None)) @@ -702,8 +701,8 @@ def make( precondition_fun: Optional[Callable[[float], float]] = None, parallel_dual_updates: bool = False, use_danskin: bool = None, - potential_initializer: Optional[init_lib.SinkhornInitializer - ] = init_lib.SinkhornInitializer(), + potential_initializer: init_lib.SinkhornInitializer = init_lib + .SinkhornInitializer(), jit: bool = False ) -> Sinkhorn: """For backward compatibility.""" diff --git a/ott/tools/gaussian_mixture/scale_tril.py b/ott/tools/gaussian_mixture/scale_tril.py index e1b49607f..3a4683405 100644 --- a/ott/tools/gaussian_mixture/scale_tril.py +++ b/ott/tools/gaussian_mixture/scale_tril.py @@ -158,10 +158,8 @@ def _flatten_cov(cov: jnp.ndarray) -> jnp.ndarray: cost_fn.pairwise(x0, x1))[...,] def transport_scale_matrix(self, dest_scale: 'ScaleTriL') -> jnp.ndarray: - """ - Scaling matrix used in transport between 0-mean normal, mu, w/ current scale to one w/ dest_scale, nu - - m = Sigma_mu ^{-1/2} [ Sigma_mu ^{1/2} Sigma_nu Sigma_mu ^{1/2}] ^{1/2}Sigma_mu ^{-1/2} + """ Scaling matrix used in transport between 0-mean normal, mu, w/ current scale to one w/ dest_scale, nu. + m = Sigma_mu ^{-1/2} [ Sigma_mu ^{1/2} Sigma_nu Sigma_mu ^{1/2}] ^{1/2}Sigma_mu ^{-1/2}. 
Args: dest_scale: destination Scale From 30455a7ebd5e8a0b49f56c091e978a25a10f03b5 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Mon, 4 Jul 2022 02:33:10 -0700 Subject: [PATCH 20/46] flake 8 formatting --- ott/core/initializers.py | 29 ++++++++++++------------ ott/core/sinkhorn.py | 6 ++--- ott/tools/gaussian_mixture/gaussian.py | 7 +++--- ott/tools/gaussian_mixture/scale_tril.py | 7 ++---- 4 files changed, 22 insertions(+), 27 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 87add48d8..67bc76638 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -26,23 +26,23 @@ class SinkhornInitializer(): def init_dual_a( self, ot_problem: LinearProblem, lse_mode: bool = True ) -> jnp.ndarray: - """ Initialzation for Sinkhorn potential f. + """Initialzation for Sinkhorn potential f. Args: ot_problem (LinearProblem): OT problem between discrete distributions of size n and m. lse_mode (bool, optional): Return log potential. Defaults to True. Returns: - jnp.ndarray: dual potential, array of size n + jnp.ndarray: dual potential, array of size n """ return self.default_dual_a(ot_problem=ot_problem, lse_mode=lse_mode) def init_dual_b( self, ot_problem: LinearProblem, lse_mode: bool = True ) -> jnp.ndarray: - """ Initialzation for Sinkhorn potential g. + """Initialzation for Sinkhorn potential g. Args: - ot_problem (LinearProblem): OT problem between discrete distributions of size n and m + ot_problem (LinearProblem): OT problem between discrete distributions of size n and m. lse_mode (bool, optional): Return log potential. Defaults to True. Returns: jnp.ndarray: dual potential, array of size m @@ -56,7 +56,7 @@ def remove_null_weight_potentials( init_dual_b: jnp.ndarray, lse_mode: bool = True ) -> Tuple[jnp.ndarray]: - """ Cancel dual variables for zero weights. + """Cancel dual variables for zero weights. Args: ot_problem (LinearProblem): @@ -66,7 +66,6 @@ def remove_null_weight_potentials( Returns: Union[jnp.ndarray]: potentials (f,g) """ - a, b = ot_problem.a, ot_problem.b init_dual_a = jnp.where(a > 0, init_dual_a, -jnp.inf if lse_mode else 0.0) init_dual_b = jnp.where(b > 0, init_dual_b, -jnp.inf if lse_mode else 0.0) @@ -75,7 +74,7 @@ def remove_null_weight_potentials( def default_dual_a( self, ot_problem: LinearProblem, lse_mode: bool = True ) -> jnp.ndarray: - """ Return array of size n, with entries 0 is lse_mode is true, otherwise entries of 1s. + """Return array of size n, with entries 0 is lse_mode is true, otherwise entries of 1s. Args: ot_problem (LinearProblem): @@ -90,7 +89,7 @@ def default_dual_a( def default_dual_b( self, ot_problem: LinearProblem, lse_mode: bool = True ) -> jnp.ndarray: - """ Return array of size m, with entries 0 is lse_mode is true, otherwise entries of 1s. + """Return array of size m, with entries 0 is lse_mode is true, otherwise entries of 1s. Args: ot_problem (LinearProblem): @@ -106,7 +105,7 @@ def default_dual_b( class GaussianInitializer(SinkhornInitializer): def __init__(self, stop_gradient: Optional[bool] = True) -> None: - """ GaussianInitializer. + """GaussianInitializer. Args: stop_gradient (bool, optional): _description_. Defaults to True. @@ -121,7 +120,7 @@ def init_dual_a( init_f: Optional[jnp.ndarray] = None, lse_mode: bool = True ) -> jnp.ndarray: - """ Gaussian init function. + """Gaussian init function. Args: ot_problem (LinearProblem): OT problem description with geometry and weights. 
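The Brenier relation that the Gaussian initializer above relies on can be sanity-checked numerically: for the ground cost ||x - y||^2 / 2, the optimal map between two Gaussians satisfies T(x) = x - grad f(x). A minimal sketch of that check, not part of the patch; it only assumes `Gaussian.from_random`, `f_potential` and `transport` as they appear in these diffs, and the dimensions and tolerance are illustrative:

import jax
import jax.numpy as jnp

from ott.tools.gaussian_mixture.gaussian import Gaussian

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
g_a = Gaussian.from_random(key_a, n_dimensions=2)
g_b = Gaussian.from_random(key_b, n_dimensions=2)
points = jax.random.normal(jax.random.PRNGKey(1), (5, 2))

# Gradient of the summed dual potential with respect to the input points.
grad_f = jax.grad(lambda p: g_a.f_potential(dest=g_b, points=p).sum())(points)

# Brenier: the optimal map for cost ||x - y||^2 / 2 is T(x) = x - grad f(x),
# which must agree with the closed-form Gaussian transport map.
expected = g_a.transport(dest=g_b, points=points)
assert jnp.allclose(points - grad_f, expected, atol=1e-4)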
@@ -163,7 +162,7 @@ def __init__( max_iter: Optional[int] = 100, stop_gradient: Optional[bool] = True ) -> None: - """ Sorting Init class. + """Sorting Init class. Args: vector_min (Optional[bool], optional): Use vectorized inner loop if true. Defaults to False. @@ -181,7 +180,7 @@ def __init__( def vectorized_update( self, f: jnp.ndarray, modified_cost: jnp.ndarray ) -> jnp.ndarray: - """ Inner loop DualSort Update. + """Inner loop DualSort Update. Args: f (jnp.ndarray): potential f, array of size n @@ -195,7 +194,7 @@ def vectorized_update( def coordinate_update( self, f: jnp.ndarray, modified_cost: jnp.ndarray ) -> jnp.ndarray: - """ Coordinate-wise updates within inner loop. + """Coordinate-wise updates within inner loop. Args: f (jnp.ndarray): potential f, array of size n @@ -214,7 +213,7 @@ def body_fn(i, f): def init_sorting_dual( self, modified_cost: jnp.ndarray, f_potential: jnp.ndarray ) -> jnp.ndarray: - """ Run DualSort algorithm. + """Run DualSort algorithm. Args: modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column @@ -250,7 +249,7 @@ def init_dual_a( init_f: jnp.ndarray = None, lse_mode: bool = True ) -> jnp.ndarray: - """ Apply DualSort algo. + """Apply DualSort algo. Args: ot_problem (LinearProblem): OT problem diff --git a/ott/core/sinkhorn.py b/ott/core/sinkhorn.py index c108e81d4..fe08f1bde 100644 --- a/ott/core/sinkhorn.py +++ b/ott/core/sinkhorn.py @@ -350,8 +350,8 @@ def __init__( use_danskin: Optional[bool] = None, implicit_diff: Optional[implicit_lib.ImplicitDiff ] = implicit_lib.ImplicitDiff(), # noqa: E124 - potential_initializer: Optional[init_lib.SinkhornInitializer - ] = init_lib.SinkhornInitializer(), + potential_initializer: init_lib.SinkhornInitializer = init_lib + .SinkhornInitializer(), jit: bool = True ): self.lse_mode = lse_mode @@ -404,7 +404,7 @@ def __call__( init: Optional[Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]]] = None ) -> SinkhornOutput: """Main interface to run sinkhorn.""" # noqa: D401 - # initialization + # initializationgit s init_dual_a, init_dual_b = (init if init is not None else (None, None)) if init_dual_a is None: diff --git a/ott/tools/gaussian_mixture/gaussian.py b/ott/tools/gaussian_mixture/gaussian.py index 267951068..c0c9fbd46 100644 --- a/ott/tools/gaussian_mixture/gaussian.py +++ b/ott/tools/gaussian_mixture/gaussian.py @@ -41,7 +41,7 @@ def __init__(self, loc: jnp.ndarray, scale: scale_tril.ScaleTriL): def from_samples( cls, x: jnp.ndarray, weights: jnp.ndarray = None ) -> 'Gaussian': - """Construct a Gaussian from weighted samples + """Construct a Gaussian from weighted samples. Args: x: [n x d] array of samples @@ -50,7 +50,6 @@ def from_samples( Returns: Gaussian. """ - if weights is None: n = x.shape[0] weights = jnp.ones(n) / n @@ -156,7 +155,7 @@ def w2_dist(self, other: 'Gaussian') -> jnp.ndarray: return delta_mean + delta_sigma def f_potential(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: - """_summary_ + """W2 distance between Gaussians. Args: dest (Gaussian): _description_ @@ -177,7 +176,7 @@ def f_potential(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: ) def transport(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: - """_summary_ + """Transport Gaussian. 
Args:
       dest (Gaussian): _description_

diff --git a/ott/tools/gaussian_mixture/scale_tril.py b/ott/tools/gaussian_mixture/scale_tril.py
index 3a4683405..d1db0b96b 100644
--- a/ott/tools/gaussian_mixture/scale_tril.py
+++ b/ott/tools/gaussian_mixture/scale_tril.py
@@ -55,7 +55,6 @@ def from_random(
       n_dimensions: number of dimensions
       stdev: desired standard deviation (around 0) for the log eigenvalues
       dtype: data type for the covariance matrix
-
     Returns:
       A ScaleTriL.
     """
@@ -141,7 +140,6 @@ def w2_dist(self, other: 'ScaleTriL') -> jnp.ndarray:

     Args:
       other: Scale for the other Gaussian
-
     Returns:
       The W_2^2 distance
     """
@@ -158,7 +156,8 @@ def _flatten_cov(cov: jnp.ndarray) -> jnp.ndarray:
            cost_fn.pairwise(x0, x1))[...,]

   def transport_scale_matrix(self, dest_scale: 'ScaleTriL') -> jnp.ndarray:
-    """ Scaling matrix used in transport between 0-mean normal, mu, w/ current scale to one w/ dest_scale, nu.
-    m = Sigma_mu ^{-1/2} [ Sigma_mu ^{1/2} Sigma_nu Sigma_mu ^{1/2}] ^{1/2}Sigma_mu ^{-1/2}.
+    """Scaling matrix used in transport between 0-mean normal, mu, w/ current scale to one w/ dest_scale, nu.
+
+    m = Sigma_mu ^{-1/2} [ Sigma_mu ^{1/2} Sigma_nu Sigma_mu ^{1/2}] ^{1/2}Sigma_mu ^{-1/2}.

     Args:
@@ -183,11 +182,9 @@ def transport(
     Args:
       dest_scale: destination Scale
       points: points to transport
-
     Returns:
       Points transported to a Gaussian with the new scale.
     """
-
     m = self.transport_scale_matrix(dest_scale)
     return jnp.transpose(jnp.matmul(m, jnp.transpose(points)))

From 17a8db96bf6011b03d5a8a885d285834c428d8b6 Mon Sep 17 00:00:00 2001
From: James Thornton
Date: Mon, 4 Jul 2022 02:48:15 -0700
Subject: [PATCH 21/46] fix typo

---
 ott/core/__init__.py | 1 -
 ott/core/sinkhorn.py | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/ott/core/__init__.py b/ott/core/__init__.py
index 67ef6898e..ec924ead4 100644
--- a/ott/core/__init__.py
+++ b/ott/core/__init__.py
@@ -32,7 +32,6 @@

 # from . import neuraldual
 from .implicit_differentiation import ImplicitDiff
-from .initializers import SinkhornInitializer
 from .linear_problems import LinearProblem
 from .sinkhorn import Sinkhorn
 from .sinkhorn_lr import LRSinkhorn

diff --git a/ott/core/sinkhorn.py b/ott/core/sinkhorn.py
index fe08f1bde..cc508dc3e 100644
--- a/ott/core/sinkhorn.py
+++ b/ott/core/sinkhorn.py
@@ -404,7 +404,7 @@ def __call__(
       init: Optional[Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]]] = None
   ) -> SinkhornOutput:
     """Main interface to run sinkhorn."""  # noqa: D401
-    # initializationgit s
+    # initialization
     init_dual_a, init_dual_b = (init if init is not None else (None, None))

     if init_dual_a is None:

From 92924bffd44b0f08d4d69188efa8cb93d7d863f6 Mon Sep 17 00:00:00 2001
From: James Thornton
Date: Mon, 4 Jul 2022 06:32:24 -0700
Subject: [PATCH 22/46] fix stop gradient in Gaussian to include weights and x,y

---
 ott/core/initializers.py               | 19 ++++++++++++-------
 ott/tools/gaussian_mixture/gaussian.py |  2 +-
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/ott/core/initializers.py b/ott/core/initializers.py
index 67bc76638..ff736cae6 100644
--- a/ott/core/initializers.py
+++ b/ott/core/initializers.py
@@ -129,21 +129,26 @@ def init_dual_a(
     Returns:
       jnp.ndarray: jnp.ndarray: potential f, array of size n.
""" - # import here due to circular imports - from ott.tools.gaussian_mixture.gaussian import Gaussian - cost_matrix = ot_problem.geom.cost_matrix - if self.stop_gradient: - cost_matrix = jax.lax.stop_gradient(cost_matrix) if not isinstance(ot_problem.geom, PointCloud): # warning that init not applied return self.default_dual_a(ot_problem, lse_mode) else: + # import here due to circular imports + from ott.tools.gaussian_mixture.gaussian import Gaussian x = ot_problem.geom.x y = ot_problem.geom.y - gaussian_a = Gaussian.from_samples(x, weights=ot_problem.a) - gaussian_b = Gaussian.from_samples(y, weights=ot_problem.b) + a = ot_problem.a + b = ot_problem.b + if self.stop_gradient: + x = jax.lax.stop_gradient(x) + y = jax.lax.stop_gradient(y) + a = jax.lax.stop_gradient(a) + b = jax.lax.stop_gradient(b) + + gaussian_a = Gaussian.from_samples(x, weights=a) + gaussian_b = Gaussian.from_samples(y, weights=b) # Brenier potential for ground cost ||x-y||^2/2, so multiple by two for cost ||x-y||^2 f_potential = 2 * gaussian_a.f_potential(dest=gaussian_b, points=x) f_potential = f_potential - jnp.mean(f_potential) diff --git a/ott/tools/gaussian_mixture/gaussian.py b/ott/tools/gaussian_mixture/gaussian.py index c0c9fbd46..ba68235a1 100644 --- a/ott/tools/gaussian_mixture/gaussian.py +++ b/ott/tools/gaussian_mixture/gaussian.py @@ -155,7 +155,7 @@ def w2_dist(self, other: 'Gaussian') -> jnp.ndarray: return delta_mean + delta_sigma def f_potential(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: - """W2 distance between Gaussians. + """Dual a potential for W2 distance between Gaussians. Args: dest (Gaussian): _description_ From 077219b07eeea2381377de5dc3f8632f14d23d1a Mon Sep 17 00:00:00 2001 From: James Thornton Date: Mon, 4 Jul 2022 06:36:31 -0700 Subject: [PATCH 23/46] fix stop gradient in Gaussian to include weights and x,y --- ott/core/initializers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index ff736cae6..5e5da4d33 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -130,7 +130,6 @@ def init_dual_a( jnp.ndarray: jnp.ndarray: potential f, array of size n. """ - if not isinstance(ot_problem.geom, PointCloud): # warning that init not applied return self.default_dual_a(ot_problem, lse_mode) From e71f6b44bf2cb7fa63bf3397aa461bcfc4297488 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Mon, 4 Jul 2022 06:39:03 -0700 Subject: [PATCH 24/46] fix docstring spaces --- ott/core/initializers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 5e5da4d33..ecebc2276 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -129,7 +129,6 @@ def init_dual_a( Returns: jnp.ndarray: jnp.ndarray: potential f, array of size n. 
""" - if not isinstance(ot_problem.geom, PointCloud): # warning that init not applied return self.default_dual_a(ot_problem, lse_mode) From 4c1f0b30dbfcfa0f4ebab92bcfb47b261f0a680e Mon Sep 17 00:00:00 2001 From: James Thornton Date: Tue, 5 Jul 2022 13:55:00 -0700 Subject: [PATCH 25/46] feedback from initial review --- ott/core/initializers.py | 260 +++++++++++++---------- ott/core/sinkhorn.py | 5 +- ott/tools/gaussian_mixture/gaussian.py | 27 ++- ott/tools/gaussian_mixture/scale_tril.py | 7 +- tests/core/initializers_test.py | 16 +- tests/core/sinkhorn_test.py | 5 +- 6 files changed, 178 insertions(+), 142 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index ecebc2276..245724d70 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -12,142 +12,163 @@ # See the License for the specific language governing permissions and # limitations under the License. """Sinkhorn initializers.""" -from typing import Optional, Tuple +from typing import Tuple import jax -from jax import numpy as jnp +import jax.numpy as jnp -from ott.core.linear_problems import LinearProblem +from ott.core import linear_problems from ott.geometry.pointcloud import PointCloud -class SinkhornInitializer(): +def default_dual_a( + ot_problem: linear_problems.LinearProblem, lse_mode: bool +) -> jnp.ndarray: + """Return dual potential vector, f. - def init_dual_a( - self, ot_problem: LinearProblem, lse_mode: bool = True - ) -> jnp.ndarray: - """Initialzation for Sinkhorn potential f. + Args: + ot_problem: + lse_mode: Return log potentials if true. Defaults to True. - Args: - ot_problem (LinearProblem): OT problem between discrete distributions of size n and m. - lse_mode (bool, optional): Return log potential. Defaults to True. - Returns: - jnp.ndarray: dual potential, array of size n - """ - return self.default_dual_a(ot_problem=ot_problem, lse_mode=lse_mode) + Returns: + potential f, array of size n + """ + a = ot_problem.a + init_dual_a = jnp.zeros_like(a) if lse_mode else jnp.ones_like(a) + return init_dual_a - def init_dual_b( - self, ot_problem: LinearProblem, lse_mode: bool = True - ) -> jnp.ndarray: - """Initialzation for Sinkhorn potential g. - Args: - ot_problem (LinearProblem): OT problem between discrete distributions of size n and m. - lse_mode (bool, optional): Return log potential. Defaults to True. - Returns: - jnp.ndarray: dual potential, array of size m - """ - return self.default_dual_b(ot_problem=ot_problem, lse_mode=lse_mode) +def default_dual_b( + ot_problem: linear_problems.LinearProblem, lse_mode: bool +) -> jnp.ndarray: + """Return dual potential vector, g. - def remove_null_weight_potentials( - self, - ot_problem: LinearProblem, - init_dual_a: jnp.ndarray, - init_dual_b: jnp.ndarray, - lse_mode: bool = True - ) -> Tuple[jnp.ndarray]: - """Cancel dual variables for zero weights. + Args: + ot_problem: + lse_mode: Return log potentials if true. Defaults to True. + + Returns: + potential g, array of size m + """ + b = ot_problem.b + init_dual_b = jnp.zeros_like(b) if lse_mode else jnp.ones_like(b) + return init_dual_b + + +def remove_weight_potential( + weights: jnp.ndarray, init_dual: jnp.ndarray, lse_mode: bool +) -> Tuple[jnp.ndarray]: + """Cancel dual variables for zero weights. + + Args: + weights: array of probability masses + init_dual: dual potential array + lse_mode: Return log potentials if true. Defaults to True. 
+
+  Returns:
+    potential
+  """
+  return jnp.where(weights > 0, init_dual, -jnp.inf if lse_mode else 0.0)
+
+
+def remove_weight_potentials(
+    weights_a: jnp.ndarray, weights_b: jnp.ndarray, init_dual_a: jnp.ndarray,
+    init_dual_b: jnp.ndarray, lse_mode: bool
+) -> Tuple[jnp.ndarray]:
+  """Cancel dual variables for zero weights.
+
+  Args:
+    weights_a: array of probability masses, array of size n
+    weights_b: array of probability masses, array of size m
+    init_dual_a: potential f, array of size n
+    init_dual_b: potential g, array of size m
+    lse_mode: Return log potentials if true. Defaults to True.
+
+  Returns:
+    potentials (f,g)
+  """
+  init_dual_a = remove_weight_potential(weights_a, init_dual_a, lse_mode)
+  init_dual_b = remove_weight_potential(weights_b, init_dual_b, lse_mode)
+  return init_dual_a, init_dual_b

-  def default_dual_a(
-      self, ot_problem: LinearProblem, lse_mode: bool = True
+
+class SinkhornInitializer:
+  """Initialization for Sinkhorn potential f.
+
+  Args:
+    ot_problem: OT problem between discrete distributions of size n and m.
+    lse_mode: Return log potential. Defaults to True.
+
+  Returns:
+    dual potential, array of size n
+  """
+
+  def init_dual_a(
+      self, ot_problem: linear_problems.LinearProblem, lse_mode: bool
   ) -> jnp.ndarray:
-    """Return array of size n, with entries 0 is lse_mode is true, otherwise entries of 1s.
-    Args:
-      ot_problem (LinearProblem):
-      lse_mode (bool, optional): Return log potentials if true. Defaults to True.
-    Returns:
-      jnp.ndarray: potential f, array of size n
-    """
-    a = ot_problem.a
-    init_dual_a = jnp.zeros_like(a) if lse_mode else jnp.ones_like(a)
-    return init_dual_a
+    return default_dual_a(ot_problem=ot_problem, lse_mode=lse_mode)

-  def default_dual_b(
-      self, ot_problem: LinearProblem, lse_mode: bool = True
+  def init_dual_b(
+      self, ot_problem: linear_problems.LinearProblem, lse_mode: bool
   ) -> jnp.ndarray:
-    """Return array of size m, with entries 0 is lse_mode is true, otherwise entries of 1s.
+    """Initialization for Sinkhorn potential g.

     Args:
-      ot_problem (LinearProblem):
-      lse_mode (bool, optional): Return log potentials if true. Defaults to True.
+      ot_problem: OT problem between discrete distributions of size n and m.
+      lse_mode: Return log potential. Defaults to True.
+
     Returns:
-      jnp.ndarray: potential fg array of size m
+      dual potential, array of size m
     """
-    b = ot_problem.b
-    init_dual_b = jnp.zeros_like(b) if lse_mode else jnp.ones_like(b)
-    return init_dual_b
+    return default_dual_b(ot_problem=ot_problem, lse_mode=lse_mode)


 class GaussianInitializer(SinkhornInitializer):
+  """GaussianInitializer.

-  def __init__(self, stop_gradient: Optional[bool] = True) -> None:
-    """GaussianInitializer.
+  Args:
+    stop_gradient: Defaults to True.
+  """
+
+  def __init__(self, stop_gradient: bool = True) -> None:
- """ super().__init__() self.stop_gradient = stop_gradient def init_dual_a( self, - ot_problem: LinearProblem, - init_f: Optional[jnp.ndarray] = None, - lse_mode: bool = True + ot_problem: linear_problems.LinearProblem, + lse_mode: bool, ) -> jnp.ndarray: """Gaussian init function. Args: - ot_problem (LinearProblem): OT problem description with geometry and weights. - init_f (Optional[jnp.ndarray], optional): Pre dual sort initialization, when none sets entries as 0. - lse_mode (bool, optional): Return log potential if true. Defaults to True. + ot_problem: OT problem description with geometry and weights. + init_f: Pre dual sort initialization, when none sets entries as 0. + lse_mode: Return log potential if true. Defaults to True. + Returns: - jnp.ndarray: jnp.ndarray: potential f, array of size n. + potential f, array of size n. """ + # import here due to circular imports + from ott.tools.gaussian_mixture.gaussian import Gaussian + if not isinstance(ot_problem.geom, PointCloud): # warning that init not applied - return self.default_dual_a(ot_problem, lse_mode) + return default_dual_a(ot_problem, lse_mode) else: - # import here due to circular imports - from ott.tools.gaussian_mixture.gaussian import Gaussian - x = ot_problem.geom.x - y = ot_problem.geom.y - a = ot_problem.a - b = ot_problem.b + + x, y = ot_problem.geom.x, ot_problem.geom.y + a, b = ot_problem.a, ot_problem.b if self.stop_gradient: - x = jax.lax.stop_gradient(x) - y = jax.lax.stop_gradient(y) - a = jax.lax.stop_gradient(a) - b = jax.lax.stop_gradient(b) + x, y = jax.lax.stop_gradient(x), jax.lax.stop_gradient(y) + a, b = jax.lax.stop_gradient(a), jax.lax.stop_gradient(b) gaussian_a = Gaussian.from_samples(x, weights=a) gaussian_b = Gaussian.from_samples(y, weights=b) - # Brenier potential for ground cost ||x-y||^2/2, so multiple by two for cost ||x-y||^2 + # Brenier potential for cost ||x-y||^2/2, multiply by two for ||x-y||^2 f_potential = 2 * gaussian_a.f_potential(dest=gaussian_b, points=x) f_potential = f_potential - jnp.mean(f_potential) f_potential = f_potential if lse_mode else ot_problem.scaling_from_potential( @@ -157,22 +178,23 @@ def init_dual_a( class SortingInit(SinkhornInitializer): + """Sorting Init class. + + Args: + vector_min: Use vectorized inner loop if true. Defaults to False. + tol: DualSort convergence threshold. Defaults to 1e-2. + max_iter: Max DualSort steps. Defaults to 100. + stop_gradient: Do not trace gradient. Defaults to True. + """ def __init__( self, - vector_min: Optional[bool] = False, - tol: Optional[float] = 1e-2, - max_iter: Optional[int] = 100, - stop_gradient: Optional[bool] = True + vector_min: bool = False, + tol: float = 1e-2, + max_iter: int = 100, + stop_gradient: bool = True ) -> None: - """Sorting Init class. - Args: - vector_min (Optional[bool], optional): Use vectorized inner loop if true. Defaults to False. - tol (Optional[float], optional): DualSort convergence threshold. Defaults to 1e-2. - max_iter (Optional[int], optional): Max DualSort steps. Defaults to 100. - stop_gradient (Optional[bool], optional): Do not trace gradient through the initializer. Defaults to True. - """ super().__init__() self.tolerance = tol @@ -186,10 +208,11 @@ def vectorized_update( """Inner loop DualSort Update. Args: - f (jnp.ndarray): potential f, array of size n - modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column + f : potential f, array of size n. + modified_cost: cost matrix minus diagonal column-wise. 
+ Returns: - jnp.ndarray: updated potential vector, f + updated potential vector, f. """ f = jnp.min(modified_cost + f[None, :], axis=1) return f @@ -200,10 +223,11 @@ def coordinate_update( """Coordinate-wise updates within inner loop. Args: - f (jnp.ndarray): potential f, array of size n - modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column + f: potential f, array of size n. + modified_cost: cost matrix minus diagonal column-wise. + Returns: - jnp.ndarray: updated potential vector, f + updated potential vector, f. """ def body_fn(i, f): @@ -219,10 +243,11 @@ def init_sorting_dual( """Run DualSort algorithm. Args: - modified_cost (jnp.ndarray): cost matrix minus diagonal of cost matrix across each column - f_potential (jnp.ndarray): potential f, array of size n + modified_cost: cost matrix minus diagonal column-wise. + f_potential: potential f, array of size n. + Returns: - jnp.ndarray: potential f, array of size n + potential f, array of size n. """ it = 0 diff = self.tolerance + 1.0 @@ -238,7 +263,7 @@ def body_fn(state): def cond_fn(state): _, diff, it = state - return (diff > self.tolerance) & (it < self.max_iter) + return jnp.logical_and(diff > self.tolerance, it < self.max_iter) f_potential, _, it = jax.lax.while_loop( cond_fun=cond_fn, body_fun=body_fn, init_val=state @@ -248,17 +273,18 @@ def cond_fn(state): def init_dual_a( self, - ot_problem: LinearProblem, + ot_problem: linear_problems.LinearProblem, + lse_mode: bool, init_f: jnp.ndarray = None, - lse_mode: bool = True ) -> jnp.ndarray: """Apply DualSort algo. Args: - ot_problem (LinearProblem): OT problem - init_f (jnp.ndarray, optional): potential f, array of size n. Defaults to None. + ot_problem: OT problem. + init_f: potential f, array of size n. Defaults to None. + Returns: - jnp.ndarray: potential f, array of size n + potential f, array of size n. """ cost_matrix = ot_problem.geom.cost_matrix if self.stop_gradient: diff --git a/ott/core/sinkhorn.py b/ott/core/sinkhorn.py index cc508dc3e..9b29c5be5 100644 --- a/ott/core/sinkhorn.py +++ b/ott/core/sinkhorn.py @@ -418,8 +418,9 @@ def __call__( ) # Cancel dual variables for zero weights. - init_dual_a, init_dual_b = self.potential_initializer.remove_null_weight_potentials( - ot_problem=ot_prob, + init_dual_a, init_dual_b = init_lib.remove_weight_potentials( + weights_a=ot_prob.a, + weights_b=ot_prob.b, init_dual_a=init_dual_a, init_dual_b=init_dual_b, lse_mode=self.lse_mode diff --git a/ott/tools/gaussian_mixture/gaussian.py b/ott/tools/gaussian_mixture/gaussian.py index ba68235a1..a83fe9f3f 100644 --- a/ott/tools/gaussian_mixture/gaussian.py +++ b/ott/tools/gaussian_mixture/gaussian.py @@ -39,24 +39,24 @@ def __init__(self, loc: jnp.ndarray, scale: scale_tril.ScaleTriL): @classmethod def from_samples( - cls, x: jnp.ndarray, weights: jnp.ndarray = None + cls, points: jnp.ndarray, weights: jnp.ndarray = None ) -> 'Gaussian': """Construct a Gaussian from weighted samples. Args: - x: [n x d] array of samples + points: [n x d] array of samples weights: [n] array of weights Returns: Gaussian. 
""" if weights is None: - n = x.shape[0] + n = points.shape[0] weights = jnp.ones(n) / n - mean = weights.dot(x) - scaled_centered_x = (x - mean) * weights.reshape(-1, 1) - cov = (scaled_centered_x).T.dot(scaled_centered_x) / weights.T.dot(weights) + mean = weights.dot(points) + scaled_centered_x = (points - mean) * weights.reshape(-1, 1) + cov = scaled_centered_x.T.dot(scaled_centered_x) / weights.T.dot(weights) return cls.from_mean_and_cov(mean=mean, cov=cov) @classmethod @@ -158,11 +158,11 @@ def f_potential(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: """Dual a potential for W2 distance between Gaussians. Args: - dest (Gaussian): _description_ - points (jnp.ndarray): _description_ + dest: Gaussian object + points: samples Returns: - jnp.ndarray: _description_ + Dual potential, f """ scale_matrix = self.scale.transport_scale_matrix(dest_scale=dest.scale) centered_x = points - self.loc @@ -171,19 +171,18 @@ def f_potential(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: ) return ( 0.5 * batch_inner_product(points, points) - - 0.5 * batch_inner_product(centered_x, scaled_x) - - (points).dot(dest.loc) + 0.5 * batch_inner_product(centered_x, scaled_x) - points.dot(dest.loc) ) def transport(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: """Transport Gaussian. Args: - dest (Gaussian): _description_ - points (jnp.ndarray): _description_ + dest: Gaussian object + points: samples Returns: - jnp.ndarray: _description_ + Transported samples """ return self.scale.transport( dest_scale=dest.scale, points=points - self.loc[None] diff --git a/ott/tools/gaussian_mixture/scale_tril.py b/ott/tools/gaussian_mixture/scale_tril.py index d1db0b96b..b31a19899 100644 --- a/ott/tools/gaussian_mixture/scale_tril.py +++ b/ott/tools/gaussian_mixture/scale_tril.py @@ -156,9 +156,11 @@ def _flatten_cov(cov: jnp.ndarray) -> jnp.ndarray: cost_fn.pairwise(x0, x1))[...,] def transport_scale_matrix(self, dest_scale: 'ScaleTriL') -> jnp.ndarray: - """Scaling matrix used in transport between 0-mean normalmu, w/ current scale to one w/ dest_scale, nu/. + """Scaling matrix used in transport between 0-mean Gaussians. - m = Sigma_mu ^{-1/2} [ Sigma_mu ^{1/2} Sigma_nu Sigma_mu ^{1/2}] ^{1/2}Sigma_mu ^{-1/2}. + Sigma_mu^{-1/2} @ + [Sigma_mu ^{1/2} Sigma_nu Sigma_mu ^{1/2}]^{1/2} + @ Sigma_mu ^{-1/2} Args: dest_scale: destination Scale @@ -182,6 +184,7 @@ def transport( Args: dest_scale: destination Scale points: points to transport + Returns: Points transported to a Gaussian with the new scale. 
""" diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py index 47bbcfa44..d2243f5f6 100644 --- a/tests/core/initializers_test.py +++ b/tests/core/initializers_test.py @@ -114,8 +114,12 @@ def test_default_initializer(self): geom = PointCloud(x_jnp, y_jnp, **geom_kwargs) ot_problem = LinearProblem(geom=geom, a=a, b=b) - default_potential_a = gaus_init.default_dual_a(ot_problem=ot_problem) - default_potential_b = gaus_init.default_dual_b(ot_problem=ot_problem) + default_potential_a = init_lib.default_dual_a( + ot_problem=ot_problem, lse_mode=True + ) + default_potential_b = init_lib.default_dual_b( + ot_problem=ot_problem, lse_mode=True + ) # check default is 0 self.assertTrue((jnp.zeros(n) == default_potential_a).all()) @@ -124,8 +128,12 @@ def test_default_initializer(self): # check gausian init returns 0 for non point cloud geometry new_geom = Geometry(cost_matrix=geom.cost_matrix, **geom_kwargs) ot_problem = LinearProblem(geom=new_geom, a=a, b=b) - init_potential_a = gaus_init.init_dual_a(ot_problem=ot_problem) - init_potential_b = gaus_init.init_dual_a(ot_problem=ot_problem) + init_potential_a = gaus_init.init_dual_a( + ot_problem=ot_problem, lse_mode=True + ) + init_potential_b = gaus_init.init_dual_b( + ot_problem=ot_problem, lse_mode=True + ) self.assertTrue((jnp.zeros(n) == init_potential_a).all()) self.assertTrue((jnp.zeros(m) == init_potential_b).all()) diff --git a/tests/core/sinkhorn_test.py b/tests/core/sinkhorn_test.py index 1ebac575c..0470547bb 100644 --- a/tests/core/sinkhorn_test.py +++ b/tests/core/sinkhorn_test.py @@ -469,14 +469,13 @@ def test_restart(self, lse_mode): num_iter_restarted = jnp.sum(errors_restarted > -1) # check we can only improve on error - # num_iter = jnp.sum(errors>-1) - # self.assertGreater(num_iter, num_iter_restarted) + num_iter = jnp.sum(errors > -1) + self.assertGreater(num_iter, num_iter_restarted) # check we can only improve on error self.assertGreater(err + threshold, err_restarted) # # check first error in restart does at least as well as previous best - self.assertGreater(err + threshold, errors_restarted[2]) self.assertGreater(err + threshold, errors_restarted[0]) # check only one iteration suffices when restarting with same data. From 99d0bd18e5ca46f3cddb91bb3b169e2e50e69e36 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Tue, 5 Jul 2022 14:04:17 -0700 Subject: [PATCH 26/46] re order local functions before state init --- ott/core/initializers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 245724d70..bd2ffa1da 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -249,10 +249,6 @@ def init_sorting_dual( Returns: potential f, array of size n. 
""" - it = 0 - diff = self.tolerance + 1.0 - - state = (f_potential, diff, it) def body_fn(state): prev_f, _, it = state @@ -265,6 +261,10 @@ def cond_fn(state): _, diff, it = state return jnp.logical_and(diff > self.tolerance, it < self.max_iter) + it = 0 + diff = self.tolerance + 1.0 + state = (f_potential, diff, it) + f_potential, _, it = jax.lax.while_loop( cond_fun=cond_fn, body_fun=body_fn, init_val=state ) From 60b973b2ebe65eff53de2505dfccc7caa3127673 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Tue, 5 Jul 2022 17:03:53 -0700 Subject: [PATCH 27/46] optional init_f in sorting init --- ott/core/initializers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index bd2ffa1da..c3cb1f5cf 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Sinkhorn initializers.""" -from typing import Tuple +from typing import Tuple, Optional import jax import jax.numpy as jnp @@ -275,7 +275,7 @@ def init_dual_a( self, ot_problem: linear_problems.LinearProblem, lse_mode: bool, - init_f: jnp.ndarray = None, + init_f: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: """Apply DualSort algo. From 3e2df88f0b93c0929849974f1155abefbf6b6f28 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Tue, 5 Jul 2022 17:07:29 -0700 Subject: [PATCH 28/46] docstring insert line before return --- ott/tools/gaussian_mixture/scale_tril.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ott/tools/gaussian_mixture/scale_tril.py b/ott/tools/gaussian_mixture/scale_tril.py index b31a19899..7ea2d32fe 100644 --- a/ott/tools/gaussian_mixture/scale_tril.py +++ b/ott/tools/gaussian_mixture/scale_tril.py @@ -55,6 +55,7 @@ def from_random( n_dimensions: number of dimensions stdev: desired standard deviation (around 0) for the log eigenvalues dtype: data type for the covariance matrix + Returns: A ScaleTriL. """ @@ -140,6 +141,7 @@ def w2_dist(self, other: 'ScaleTriL') -> jnp.ndarray: Args: other: Scale for the other Gaussian + Returns: The W_2^2 distance """ From 14f3b64e02d025c04b749d2fa7d37433bad60a63 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Tue, 5 Jul 2022 17:12:15 -0700 Subject: [PATCH 29/46] lint fix --- ott/core/initializers.py | 2 +- ott/tools/gaussian_mixture/scale_tril.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index c3cb1f5cf..992984392 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Sinkhorn initializers.""" -from typing import Tuple, Optional +from typing import Optional, Tuple import jax import jax.numpy as jnp diff --git a/ott/tools/gaussian_mixture/scale_tril.py b/ott/tools/gaussian_mixture/scale_tril.py index 7ea2d32fe..aa8f8758f 100644 --- a/ott/tools/gaussian_mixture/scale_tril.py +++ b/ott/tools/gaussian_mixture/scale_tril.py @@ -55,7 +55,7 @@ def from_random( n_dimensions: number of dimensions stdev: desired standard deviation (around 0) for the log eigenvalues dtype: data type for the covariance matrix - + Returns: A ScaleTriL. 
""" From 5d1f648b1f8915fe7e60ee5e8fdfb928a94bedcf Mon Sep 17 00:00:00 2001 From: James Thornton Date: Tue, 12 Jul 2022 14:07:41 -0700 Subject: [PATCH 30/46] incorporate feedback in commit --- ott/core/initializers.py | 58 ++++++++++++++++-------- ott/tools/gaussian_mixture/gaussian.py | 37 ++++++++------- ott/tools/gaussian_mixture/scale_tril.py | 10 ++-- ott/tools/transport.py | 7 ++- 4 files changed, 68 insertions(+), 44 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 992984392..e5522d082 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """Sinkhorn initializers.""" +from pickle import TRUE from typing import Optional, Tuple import jax import jax.numpy as jnp from ott.core import linear_problems -from ott.geometry.pointcloud import PointCloud - +from ott.geometry import pointcloud def default_dual_a( ot_problem: linear_problems.LinearProblem, lse_mode: bool @@ -28,7 +28,7 @@ def default_dual_a( Args: ot_problem: - lse_mode: Return log potentials if true. Defaults to True. + lse_mode: Return potentials if true, scaling if false. Returns: potential f, array of size n @@ -45,7 +45,7 @@ def default_dual_b( Args: ot_problem: - lse_mode: Return log potentials if true. Defaults to True. + lse_mode: Return potentials if true, scaling if false. Returns: potential g, array of size m @@ -63,8 +63,7 @@ def remove_weight_potential( Args: weights: array of probability masses init_dual: dual potential array - lse_mode: Return log potentials if true. Defaults to True. - + lse_mode: Return potentials if true, scaling if false. Returns: potential """ @@ -82,7 +81,7 @@ def remove_weight_potentials( weights_b: array of probability masses, array of size m init_dual_a: potential f, array of size n init_dual_b: potential g, array of size m - lse_mode: Return log potentials if true. Defaults to True. + lse_mode: Return potentials if true, scaling if false. Returns: potentials (f,g) @@ -93,11 +92,11 @@ def remove_weight_potentials( class SinkhornInitializer: - """Initialzation for Sinkhorn potential f. + """Initialization. Args: ot_problem: OT problem between discrete distributions of size n and m. - lse_mode: Return log potential. Defaults to True. + lse_mode: Return potential if true, scaling if false. Returns: dual potential, array of size n @@ -106,6 +105,15 @@ class SinkhornInitializer: def init_dual_a( self, ot_problem: linear_problems.LinearProblem, lse_mode: bool ) -> jnp.ndarray: + """Initialzation for Sinkhorn potential f. + + Args: + ot_problem: OT problem between discrete distributions of size n and m. + lse_mode: Return potential if true, scaling if false. + + Returns: + dual potential, array of size n + """ return default_dual_a(ot_problem=ot_problem, lse_mode=lse_mode) @@ -116,7 +124,7 @@ def init_dual_b( Args: ot_problem: OT problem between discrete distributions of size n and m. - lse_mode: Return log potential. Defaults to True. + lse_mode: Return potential if true, scaling if false. Returns: dual potential, array of size m @@ -126,6 +134,12 @@ def init_dual_b( class GaussianInitializer(SinkhornInitializer): """GaussianInitializer. + + From https://arxiv.org/abs/2206.07630. + Compute Gaussian approximations of each pointcloud, then compute closed from + Kantorovic potential betwen Gaussian approximations using Brenier's theorem + (adapt convex/ Brenier potential to Kantoroic). 
diff --git a/ott/tools/gaussian_mixture/gaussian.py b/ott/tools/gaussian_mixture/gaussian.py
index a83fe9f3f..16b5e1829 100644
--- a/ott/tools/gaussian_mixture/gaussian.py
+++ b/ott/tools/gaussian_mixture/gaussian.py
@@ -24,11 +24,6 @@

 LOG2PI = math.log(2. * math.pi)

-@jax.vmap
-def batch_inner_product(x, y):
-  return x.dot(y)
-
-
 @jax.tree_util.register_pytree_node_class
 class Gaussian:
   """PyTree for a normal distribution."""
@@ -39,24 +34,27 @@ class Gaussian:
   @classmethod
   def from_samples(
       cls, points: jnp.ndarray, weights: jnp.ndarray = None
   ) -> 'Gaussian':
     """Construct a Gaussian from weighted samples.

+    Unbiased, weighted covariance formular from https://en.wikipedia.org/wiki/Sample_mean_and_covariance#Weighted_samples
+    and https://www.gnu.org/software/gsl/doc/html/statistics.html?highlight=weighted#weighted-samples
+
     Args:
       points: [n x d] array of samples
       weights: [n] array of weights

     Returns:
       Gaussian.
     """
+    n = points.shape[0]
     if weights is None:
-      n = points.shape[0]
       weights = jnp.ones(n) / n

     mean = weights.dot(points)
-    scaled_centered_x = (points - mean) * weights.reshape(-1, 1)
-    cov = scaled_centered_x.T.dot(scaled_centered_x) / weights.T.dot(weights)
+    centered_x = (points - mean)
+    scaled_centered_x = centered_x * weights.reshape(-1, 1)
+    cov = scaled_centered_x.T.dot(centered_x) / (1 - weights.dot(weights))
     return cls.from_mean_and_cov(mean=mean, cov=cov)

   @classmethod
@@ -155,7 +154,7 @@ class Gaussian:
     return delta_mean + delta_sigma

   def f_potential(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray:
-    """Dual a potential for W2 distance between Gaussians.
+    """Evaluate optimal dual potential for W2 distance between Gaussians.

     Args:
       dest: Gaussian object
       points: samples

     Returns:
       Dual potential, f
     """
-    scale_matrix = self.scale.transport_scale_matrix(dest_scale=dest.scale)
+    scale_matrix = self.scale.gaussian_map(dest_scale=dest.scale)
     centered_x = points - self.loc
-    scaled_x = jnp.transpose(
-        jnp.matmul(scale_matrix, jnp.transpose(centered_x))
-    )
+    # scaled_x = jnp.transpose(
+    #     jnp.matmul(scale_matrix, jnp.transpose(centered_x))
+    # )
+    scaled_x = (scale_matrix @ centered_x.T)
+
+    @jax.vmap
+    def batch_inner_product(x, y):
+      return x.dot(y)
+
     return (
         0.5 * batch_inner_product(points, points) -
-        0.5 * batch_inner_product(centered_x, scaled_x) - points.dot(dest.loc)
+        0.5 * batch_inner_product(centered_x, scaled_x.T) -
+        points.dot(dest.loc)
     )

   def transport(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray:
-    """Transport Gaussian.
+    """Transport points according to map between two Gaussian measures.

     Args:
       dest: Gaussian object
       points: samples

     Returns:
       Transported samples
     """
     return self.scale.transport(
         dest_scale=dest.scale, points=points - self.loc[None]
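The covariance estimator in `from_samples` above, restated standalone: with probability weights summing to one, the unbiased weighted covariance divides by 1 - sum_i w_i^2, which for uniform weights reduces to the familiar n - 1 correction. A self-contained sketch of the same formula (not the library code):

    import jax.numpy as jnp

    def weighted_mean_and_cov(points, weights):
      # weights are probability masses: sum(weights) == 1
      mean = weights.dot(points)
      centered = points - mean
      # sum_i w_i (x_i - mean)(x_i - mean)^T / (1 - sum_i w_i^2)
      cov = (centered * weights[:, None]).T.dot(centered)
      return mean, cov / (1. - weights.dot(weights))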
""" + n = points.shape[0] if weights is None: - n = points.shape[0] - weights = jnp.ones(n) / n - + weights = jnp.ones(n) / n + mean = weights.dot(points) - scaled_centered_x = (points - mean) * weights.reshape(-1, 1) - cov = scaled_centered_x.T.dot(scaled_centered_x) / weights.T.dot(weights) + centered_x = (points - mean) + scaled_centered_x = centered_x * weights.reshape(-1, 1) + cov = scaled_centered_x.T.dot(centered_x) / (1-weights.dot(weights)) return cls.from_mean_and_cov(mean=mean, cov=cov) @classmethod @@ -155,7 +154,7 @@ def w2_dist(self, other: 'Gaussian') -> jnp.ndarray: return delta_mean + delta_sigma def f_potential(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: - """Dual a potential for W2 distance between Gaussians. + """Evaluate optimal dual potential for W2 distance between Gaussians. Args: dest: Gaussian object @@ -166,16 +165,22 @@ def f_potential(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: """ scale_matrix = self.scale.transport_scale_matrix(dest_scale=dest.scale) centered_x = points - self.loc - scaled_x = jnp.transpose( - jnp.matmul(scale_matrix, jnp.transpose(centered_x)) - ) + # scaled_x = jnp.transpose( + # jnp.matmul(scale_matrix, jnp.transpose(centered_x)) + # ) + scaled_x = (scale_matrix @ centered_x.T) + + @jax.vmap + def batch_inner_product(x, y): + return x.dot(y) + return ( 0.5 * batch_inner_product(points, points) - - 0.5 * batch_inner_product(centered_x, scaled_x) - points.dot(dest.loc) + 0.5 * batch_inner_product(centered_x, scaled_x.T) - points.dot(dest.loc) ) def transport(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: - """Transport Gaussian. + """Transport points according to map between two Gaussian measures. Args: dest: Gaussian object diff --git a/ott/tools/gaussian_mixture/scale_tril.py b/ott/tools/gaussian_mixture/scale_tril.py index aa8f8758f..e7d983a87 100644 --- a/ott/tools/gaussian_mixture/scale_tril.py +++ b/ott/tools/gaussian_mixture/scale_tril.py @@ -69,7 +69,7 @@ def from_random( ) # random positive definite matrix - sigma = jnp.matmul(jnp.expand_dims(eigs, -2) * q, jnp.transpose(q)) + sigma = q * jnp.expand_dims(eigs, -2) @ q.T # cholesky factorization chol = jnp.linalg.cholesky(sigma) @@ -117,7 +117,7 @@ def cholesky(self) -> jnp.ndarray: def covariance(self) -> jnp.ndarray: """Get the covariance matrix.""" cholesky = self.cholesky() - return jnp.matmul(cholesky, jnp.transpose(cholesky)) + return cholesky @ cholesky.T def covariance_sqrt(self) -> jnp.ndarray: """Get the square root of the covariance matrix.""" @@ -134,7 +134,7 @@ def centered_to_z(self, x_centered: jnp.ndarray) -> jnp.ndarray: def z_to_centered(self, z: jnp.ndarray) -> jnp.ndarray: """Scale standardized points to points with the specified covariance.""" - return jnp.transpose(jnp.matmul(self.cholesky(), jnp.transpose(z))) + return (self.cholesky() @ z.T).T def w2_dist(self, other: 'ScaleTriL') -> jnp.ndarray: r"""Wasserstein distance W_2^2 to another Gaussian with same mean. @@ -181,7 +181,7 @@ def transport_scale_matrix(self, dest_scale: 'ScaleTriL') -> jnp.ndarray: def transport( self, dest_scale: 'ScaleTriL', points: jnp.ndarray ) -> jnp.ndarray: - """Transport between 0-mean normal w/ current scale to one w/ dest_scale. + """Apply Monge map between 0-mean Gaussians. Args: dest_scale: destination Scale @@ -191,7 +191,7 @@ def transport( Points transported to a Gaussian with the new scale. 
""" m = self.transport_scale_matrix(dest_scale) - return jnp.transpose(jnp.matmul(m, jnp.transpose(points))) + return (m @ points.T).T def tree_flatten(self): children = (self.params,) diff --git a/ott/tools/transport.py b/ott/tools/transport.py index 983ea8a75..7ca519d95 100644 --- a/ott/tools/transport.py +++ b/ott/tools/transport.py @@ -78,6 +78,8 @@ def solve( *args: Any, a: Optional[jnp.ndarray] = None, b: Optional[jnp.ndarray] = None, + init_dual_a: Optional[jnp.ndarray] = None, + init_dual_b: Optional[jnp.ndarray] = None, objective: Optional[Literal['linear', 'quadratic', 'fused']] = None, **kwargs: Any ) -> Transport: @@ -122,11 +124,8 @@ def solve( solver_fn = sinkhorn.make if linear else gromov_wasserstein.make geom_keys = ['cost_fn', 'power', 'online'] - init_dual_a = kwargs.get('init_dual_a', None) - init_dual_b = kwargs.get('init_dual_b', None) - init_keys = ['init_dual_a', 'init_dual_b'] - remove_keys = init_keys + geom_keys + eps_keys if linear else geom_keys + remove_keys = geom_keys + eps_keys if linear else geom_keys for key in remove_keys: kwargs.pop(key, None) solver = solver_fn(**kwargs) From aca73e7b920f566e0d13796069070e99c847fff5 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Tue, 12 Jul 2022 17:54:57 -0700 Subject: [PATCH 31/46] tidy tests, use jax.lax.cond for logic instead of if --- ott/core/initializers.py | 41 +++-- ott/tools/gaussian_mixture/gaussian.py | 20 +- ott/tools/gaussian_mixture/scale_tril.py | 10 +- ott/tools/transport.py | 1 - tests/core/initializers_test.py | 225 ++++++++++++----------- 5 files changed, 158 insertions(+), 139 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index e5522d082..c1bea3aae 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Sinkhorn initializers.""" -from pickle import TRUE from typing import Optional, Tuple import jax @@ -21,7 +20,8 @@ from ott.core import linear_problems from ott.geometry import pointcloud -def default_dual_a( + +def _default_dual_a( ot_problem: linear_problems.LinearProblem, lse_mode: bool ) -> jnp.ndarray: """Return dual potential vector, f. @@ -38,7 +38,7 @@ def default_dual_a( return init_dual_a -def default_dual_b( +def _default_dual_b( ot_problem: linear_problems.LinearProblem, lse_mode: bool ) -> jnp.ndarray: """Return dual potential vector, g. @@ -55,7 +55,7 @@ def default_dual_b( return init_dual_b -def remove_weight_potential( +def _remove_single_weight_potential( weights: jnp.ndarray, init_dual: jnp.ndarray, lse_mode: bool ) -> Tuple[jnp.ndarray]: """Cancel dual variables for zero weights. @@ -86,13 +86,17 @@ def remove_weight_potentials( Returns: potentials (f,g) """ - init_dual_a = remove_weight_potential(weights_a, init_dual_a, lse_mode) - init_dual_b = remove_weight_potential(weights_b, init_dual_b, lse_mode) + init_dual_a = _remove_single_weight_potential( + weights_a, init_dual_a, lse_mode + ) + init_dual_b = _remove_single_weight_potential( + weights_b, init_dual_b, lse_mode + ) return init_dual_a, init_dual_b class SinkhornInitializer: - """Initialization. + """Initialization of Sinkhorn dual potentials. Args: ot_problem: OT problem between discrete distributions of size n and m. 
@@ -114,8 +118,7 @@ def init_dual_a( Returns: dual potential, array of size n """ - - return default_dual_a(ot_problem=ot_problem, lse_mode=lse_mode) + return _default_dual_a(ot_problem=ot_problem, lse_mode=lse_mode) def init_dual_b( self, ot_problem: linear_problems.LinearProblem, lse_mode: bool @@ -129,12 +132,12 @@ def init_dual_b( Returns: dual potential, array of size m """ - return default_dual_b(ot_problem=ot_problem, lse_mode=lse_mode) + return _default_dual_b(ot_problem=ot_problem, lse_mode=lse_mode) class GaussianInitializer(SinkhornInitializer): """GaussianInitializer. - + From https://arxiv.org/abs/2206.07630. Compute Gaussian approximations of each pointcloud, then compute closed from Kantorovic potential betwen Gaussian approximations using Brenier's theorem @@ -161,17 +164,17 @@ def init_dual_a( Args: ot_problem: OT problem description with geometry and weights. init_f: Pre dual sort initialization, when none sets entries as 0. - lse_mode: Return potential if true, scaling if false. + lse_mode: Return potential if true, scaling if false. Returns: potential f, array of size n. """ # import Gaussian here due to circular imports - from ott.tools.gaussian_mixture import gaussian + from ott.tools.gaussian_mixture import gaussian if not isinstance(ot_problem.geom, pointcloud.PointCloud): # warning that init not applied - return default_dual_a(ot_problem, lse_mode) + return _default_dual_a(ot_problem, lse_mode) else: x, y = ot_problem.geom.x, ot_problem.geom.y @@ -194,9 +197,9 @@ def init_dual_a( class SortingInit(SinkhornInitializer): """Sorting Init class. - DualSort algorithm from https://arxiv.org/abs/2206.07630, solve - non-regularized OT problem via sorting, then compute potential through - iterated minimum on C-transform and use this potentials to initialize + DualSort algorithm from https://arxiv.org/abs/2206.07630, solve + non-regularized OT problem via sorting, then compute potential through + iterated minimum on C-transform and use this potentials to initialize regularized potential Args: @@ -219,7 +222,9 @@ def __init__( self.tolerance = tol self.stop_gradient = stop_gradient self.max_iter = max_iter - self.update_fn = self.vectorized_update if vector_min else self.coordinate_update + self.update_fn = lambda f, mod_cost: jax.lax.cond( + vector_min, self.vectorized_update, self.coordinate_update, f, mod_cost + ) def vectorized_update( self, f: jnp.ndarray, modified_cost: jnp.ndarray diff --git a/ott/tools/gaussian_mixture/gaussian.py b/ott/tools/gaussian_mixture/gaussian.py index 16b5e1829..bd7b5ac50 100644 --- a/ott/tools/gaussian_mixture/gaussian.py +++ b/ott/tools/gaussian_mixture/gaussian.py @@ -38,7 +38,7 @@ def from_samples( ) -> 'Gaussian': """Construct a Gaussian from weighted samples. 
diff --git a/ott/tools/gaussian_mixture/gaussian.py b/ott/tools/gaussian_mixture/gaussian.py
index 16b5e1829..bd7b5ac50 100644
--- a/ott/tools/gaussian_mixture/gaussian.py
+++ b/ott/tools/gaussian_mixture/gaussian.py
@@ -38,7 +38,7 @@ class Gaussian:
   ) -> 'Gaussian':
     """Construct a Gaussian from weighted samples.

-    Unbiased, weighted covariance formular from https://en.wikipedia.org/wiki/Sample_mean_and_covariance#Weighted_samples
+    Unbiased, weighted covariance formula from https://en.wikipedia.org/wiki/Sample_mean_and_covariance#Weighted_samples
     and https://www.gnu.org/software/gsl/doc/html/statistics.html?highlight=weighted#weighted-samples

     Args:
       points: [n x d] array of samples
       weights: [n] array of weights

     Returns:
       Gaussian.
     """
     n = points.shape[0]
     if weights is None:
-      weights = jnp.ones(n)/ n
-
+      weights = jnp.ones(n) / n
+
     mean = weights.dot(points)
     centered_x = (points - mean)
     scaled_centered_x = centered_x * weights.reshape(-1, 1)
-    cov = scaled_centered_x.T.dot(centered_x) / (1-weights.dot(weights))
+    cov = scaled_centered_x.T.dot(centered_x) / (1 - weights.dot(weights))
     return cls.from_mean_and_cov(mean=mean, cov=cov)

   @classmethod
@@ -155,7 +154,7 @@ class Gaussian:

   def f_potential(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray:
-    """Evaluate optimal dual potential for W2 distance between Gaussians.
+    """Optimal potential for W2 distance between Gaussians. Evaluated on points.

     Args:
       dest: Gaussian object
       points: samples

     Returns:
       Dual potential, f
     """
     scale_matrix = self.scale.gaussian_map(dest_scale=dest.scale)
     centered_x = points - self.loc
-    # scaled_x = jnp.transpose(
-    #     jnp.matmul(scale_matrix, jnp.transpose(centered_x))
-    # )
     scaled_x = (scale_matrix @ centered_x.T)

     @jax.vmap
     def batch_inner_product(x, y):
       return x.dot(y)

     return (
         0.5 * batch_inner_product(points, points) -
         0.5 * batch_inner_product(centered_x, scaled_x.T) -
         points.dot(dest.loc)
     )

   def transport(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray:
     """Transport points according to map between two Gaussian measures.

     Args:
       dest: Gaussian object
       points: samples

     Returns:
       Transported samples
     """
     return self.scale.transport(
         dest_scale=dest.scale, points=points - self.loc[None]

diff --git a/ott/tools/gaussian_mixture/scale_tril.py b/ott/tools/gaussian_mixture/scale_tril.py
index e7d983a87..c4f7ea077 100644
--- a/ott/tools/gaussian_mixture/scale_tril.py
+++ b/ott/tools/gaussian_mixture/scale_tril.py
@@ -69,7 +69,7 @@ class ScaleTriL:
     )

     # random positive definite matrix
-    sigma = q * jnp.expand_dims(eigs, -2) @ q.T
+    sigma = q * jnp.expand_dims(eigs, -2) @ q.T

     # cholesky factorization
     chol = jnp.linalg.cholesky(sigma)
@@ -181,7 +181,7 @@ class ScaleTriL:
   def transport(
       self, dest_scale: 'ScaleTriL', points: jnp.ndarray
   ) -> jnp.ndarray:
-    """Apply Monge map between 0-mean Gaussians.
+    """Apply Monge map, computed between two 0-mean Gaussians, to points.
Args: dest_scale: destination Scale @@ -190,7 +190,7 @@ def transport( Returns: Points transported to a Gaussian with the new scale. """ - m = self.transport_scale_matrix(dest_scale) + m = self.gaussian_map(dest_scale) return (m @ points.T).T def tree_flatten(self): diff --git a/ott/tools/transport.py b/ott/tools/transport.py index 7ca519d95..73674ed78 100644 --- a/ott/tools/transport.py +++ b/ott/tools/transport.py @@ -124,7 +124,6 @@ def solve( solver_fn = sinkhorn.make if linear else gromov_wasserstein.make geom_keys = ['cost_fn', 'power', 'online'] - remove_keys = geom_keys + eps_keys if linear else geom_keys for key in remove_keys: kwargs.pop(key, None) diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py index d2243f5f6..8ccda6e7f 100644 --- a/tests/core/initializers_test.py +++ b/tests/core/initializers_test.py @@ -20,10 +20,38 @@ from absl.testing import absltest, parameterized from ott.core import initializers as init_lib -from ott.core.linear_problems import LinearProblem +from ott.core import linear_problems from ott.core.sinkhorn import sinkhorn -from ott.geometry.geometry import Geometry -from ott.geometry.pointcloud import PointCloud +from ott.geometry import geometry, pointcloud + + +# define sinkhorn functions +@jax.jit +def run_sinkhorn_sort_init(x, y, a=None, b=None, epsilon=0.01, vector_min=True): + geom = pointcloud.PointCloud(x, y, epsilon=epsilon) + sort_init = init_lib.SortingInit(vector_min=vector_min) + out = sinkhorn(geom, a=a, b=b, jit=True, potential_initializer=sort_init) + return out + + +@jax.jit +def run_sinkhorn(x, y, a=None, b=None, epsilon=0.01): + geom = pointcloud.PointCloud(x, y, epsilon=epsilon) + out = sinkhorn(geom, a=a, b=b, jit=True) + return out + + +@jax.jit +def run_sinkhorn_gaus_init(x, y, a=None, b=None, epsilon=0.01): + geom = pointcloud.PointCloud(x, y, epsilon=epsilon) + out = sinkhorn( + geom, + a=a, + b=b, + jit=True, + potential_initializer=init_lib.GaussianInitializer() + ) + return out class InitializerTest(parameterized.TestCase): @@ -32,42 +60,17 @@ def setUp(self): super().setUp() self.rng = jax.random.PRNGKey(0) - def test_sorting_init(self): - """Tests sorting dual initializer.""" - - # init initializer - sort_init = init_lib.SortingInit(vector_min=True) - - # define sinkhorn functions - @jax.jit - def run_sinkhorn_sort_init(x, y, a=None, b=None, init_dual_a=None): - sink_kwargs = { - 'jit': True, - 'threshold': 0.001, - 'max_iterations': 10 ** 5, - 'potential_initializer': sort_init - } - geom_kwargs = {'epsilon': 0.01} - geom = PointCloud(x, y, **geom_kwargs) - out = sinkhorn(geom, a=a, b=b, init_dual_a=init_dual_a, **sink_kwargs) - return out - - @jax.jit - def run_sinkhorn(x, y, a=None, b=None, init_dual_a=None): - sink_kwargs = {'jit': True, 'threshold': 0.001, 'max_iterations': 10 ** 5} - geom_kwargs = {'epsilon': 0.01} - geom = PointCloud(x, y, **geom_kwargs) - out = sinkhorn(geom, a=a, b=b, init_dual_a=init_dual_a, **sink_kwargs) - return out - + def create_sorting_problem(self, n, epsilon=0.01): # definte ot problem - x_init = np.array([-1., 0, .22]) - y_init = np.array([0., 0, 1.1]) + x_init = jnp.array([-1., 0, .22]) + y_init = jnp.array([0., 0, 1.1]) - buf = 500 - np.random.seed(0) - x = np.concatenate([x_init, 10 + np.abs(np.random.normal(size=buf))]) * 5 - y = np.concatenate([y_init, 10 + np.abs(np.random.normal(size=buf))]) * 5 + x = jnp.concatenate([ + x_init, 10 + jnp.abs(jax.random.normal(self.rng, (n,))) + ]) * 5 + y = jnp.concatenate([ + y_init, 10 + 
jnp.abs(jax.random.normal(self.rng, (n,))) + ]) * 5 x = np.sort(x) y = np.sort(y) @@ -77,47 +80,77 @@ def run_sinkhorn(x, y, a=None, b=None, init_dual_a=None): a = np.ones(n) / n b = np.ones(m) / m - x_jnp, y_jnp = jnp.array(x.reshape(-1, 1)), jnp.array(y.reshape(-1, 1)) - - # run sinkhorn - sink_out = run_sinkhorn(x=x_jnp, y=y_jnp, a=a, b=b) - base_num_iter = jnp.sum(sink_out.errors > -1) - - sink_out = run_sinkhorn_sort_init(x=x_jnp, y=y_jnp, a=a, b=b) - sort_num_iter = jnp.sum(sink_out.errors > -1) - - # check initializer is better - self.assertTrue(base_num_iter >= sort_num_iter) + geom = pointcloud.PointCloud( + x.reshape(-1, 1), y.reshape(-1, 1), epsilon=epsilon + ) + ot_problem = linear_problems.LinearProblem(geom=geom, a=a, b=b) - def test_default_initializer(self): - """Tests default initializer""" + return ot_problem + def create_ot_problem(self, n, m, d, epsilon=0.01): # definte ot problem np.random.seed(0) - n, d = 1000, 2 + mu_a = np.array([-1, 1]) * 5 mu_b = np.array([0, 0]) - x = np.random.normal(size=n * d).reshape(n, d) + mu_a - y = np.random.normal(size=n * d).reshape(n, d) + mu_b + x = jax.random.normal(self.rng, (n, d)) + mu_a + y = jax.random.normal(self.rng, (m, d)) + mu_b - n = len(x) - m = len(y) a = np.ones(n) / n b = np.ones(m) / m x_jnp, y_jnp = jnp.array(x), jnp.array(y) - gaus_init = init_lib.GaussianInitializer() + geom = pointcloud.PointCloud(x_jnp, y_jnp, epsilon=epsilon) + + ot_problem = linear_problems.LinearProblem(geom=geom, a=a, b=b) + return ot_problem + + @parameterized.parameters([True], [False]) + def test_sorting_init(self, vector_min): + """Tests sorting dual initializer.""" + + n = 500 + epsilon = 0.01 + + ot_problem = self.create_sorting_problem(n=n, epsilon=epsilon) + # run sinkhorn + sink_out = run_sinkhorn( + x=ot_problem.geom.x, + y=ot_problem.geom.y, + a=ot_problem.a, + b=ot_problem.b, + epsilon=epsilon + ) + base_num_iter = jnp.sum(sink_out.errors > -1) + + sink_out = run_sinkhorn_sort_init( + x=ot_problem.geom.x, + y=ot_problem.geom.y, + a=ot_problem.a, + b=ot_problem.b, + epsilon=epsilon, + vector_min=vector_min + ) + sort_num_iter = jnp.sum(sink_out.errors > -1) + + # check initializer is better + self.assertTrue(base_num_iter >= sort_num_iter) - geom_kwargs = {'epsilon': 0.01} - geom = PointCloud(x_jnp, y_jnp, **geom_kwargs) + def test_default_initializer(self): + """Tests default initializer""" + n = 200 + m = 200 + d = 2 + epsilon = 0.01 - ot_problem = LinearProblem(geom=geom, a=a, b=b) - default_potential_a = init_lib.default_dual_a( + ot_problem = self.create_ot_problem(n, m, d) + + default_potential_a = init_lib._default_dual_a( ot_problem=ot_problem, lse_mode=True ) - default_potential_b = init_lib.default_dual_b( + default_potential_b = init_lib._default_dual_b( ot_problem=ot_problem, lse_mode=True ) @@ -126,8 +159,14 @@ def test_default_initializer(self): self.assertTrue((jnp.zeros(m) == default_potential_b).all()) # check gausian init returns 0 for non point cloud geometry - new_geom = Geometry(cost_matrix=geom.cost_matrix, **geom_kwargs) - ot_problem = LinearProblem(geom=new_geom, a=a, b=b) + # init initializer + gaus_init = init_lib.GaussianInitializer() + new_geom = geometry.Geometry( + cost_matrix=ot_problem.geom.cost_matrix, epsilon=epsilon + ) + ot_problem = linear_problems.LinearProblem( + geom=new_geom, a=ot_problem.a, b=ot_problem.b + ) init_potential_a = gaus_init.init_dual_a( ot_problem=ot_problem, lse_mode=True ) @@ -140,53 +179,31 @@ def test_default_initializer(self): def test_gaus_initializer(self): """Tests 
Gaussian initializer."""
-
-    # init initializer
-    gaus_init = init_lib.GaussianInitializer()
-
-    @jax.jit
-    def run_sinkhorn(x, y, a=None, b=None, init_dual_a=None):
-      sink_kwargs = {'jit': True, 'threshold': 0.001, 'max_iterations': 10 ** 5}
-      geom_kwargs = {'epsilon': 0.01}
-      geom = PointCloud(x, y, **geom_kwargs)
-      out = sinkhorn(geom, a=a, b=b, init_dual_a=init_dual_a, **sink_kwargs)
-      return out
-
-    @jax.jit
-    def run_sinkhorn_gaus_init(x, y, a=None, b=None, init_dual_a=None):
-      sink_kwargs = {
-          'jit': True,
-          'threshold': 0.001,
-          'max_iterations': 10 ** 5,
-          'potential_initializer': gaus_init
-      }
-
-      geom_kwargs = {'epsilon': 0.01}
-      geom = PointCloud(x, y, **geom_kwargs)
-      out = sinkhorn(geom, a=a, b=b, init_dual_a=init_dual_a, **sink_kwargs)
-      return out
-
     # definte ot problem
-    np.random.seed(0)
-    n, d = 1000, 2
-    mu_a = np.array([-1, 1]) * 5
-    mu_b = np.array([0, 0])
-
-    x = np.random.normal(size=n * d).reshape(n, d) + mu_a
-    y = np.random.normal(size=n * d).reshape(n, d) + mu_b
+    n = 200
+    m = 200
+    d = 2
+    epsilon = 0.01

-    n = len(x)
-    m = len(y)
-    a = np.ones(n) / n
-    b = np.ones(m) / m
-
-    x_jnp, y_jnp = jnp.array(x), jnp.array(y)
+    ot_problem = self.create_ot_problem(n, m, d)

     # run sinkhorn
-    sink_out = run_sinkhorn(x=x_jnp, y=y_jnp, a=a, b=b)
+    sink_out = run_sinkhorn(
+        x=ot_problem.geom.x,
+        y=ot_problem.geom.y,
+        a=ot_problem.a,
+        b=ot_problem.b,
+        epsilon=epsilon
+    )
     base_num_iter = jnp.sum(sink_out.errors > -1)

-    sink_out = run_sinkhorn_gaus_init(x=x_jnp, y=y_jnp, a=a, b=b)
+    sink_out = run_sinkhorn_gaus_init(
+        x=ot_problem.geom.x,
+        y=ot_problem.geom.y,
+        a=ot_problem.a,
+        b=ot_problem.b,
+        epsilon=epsilon
+    )
     gaus_num_iter = jnp.sum(sink_out.errors > -1)

     # check initializer is better
-    self.assertTrue(base_num_iter >= gaus_num_iter)
+    self.assertGreaterEqual(base_num_iter, gaus_num_iter)


 if __name__ == '__main__':

From 86b32f55fcd3ed2e2e5050b528d0acb307828d63 Mon Sep 17 00:00:00 2001
From: James Thornton
Date: Tue, 12 Jul 2022 18:32:59 -0700
Subject: [PATCH 32/46] add docs, rename sorting initializer

---
 docs/core.rst                   |  9 ++++++++
 ott/core/initializers.py        |  2 +-
 tests/core/initializers_test.py | 38 ++++++++++++++++-----------------
 tests/core/sinkhorn_test.py     | 21 +++++++++++-------
 4 files changed, 41 insertions(+), 29 deletions(-)

diff --git a/docs/core.rst b/docs/core.rst
index c39d9d3a6..8399207c4 100644
--- a/docs/core.rst
+++ b/docs/core.rst
@@ -31,6 +31,15 @@ Sinkhorn
     sinkhorn.Sinkhorn
     sinkhorn.SinkhornOutput

+Sinkhorn Dual Initializers
+--------------------------
+.. autosummary::
+    :toctree: _autosummary
+
+    initializers.SinkhornInitializer
+    initializers.GaussianInitializer
+    initializers.SortingInitializer
+
 Low-Rank Sinkhorn
 -----------------
 .. autosummary::
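Alongside the docs entry, a sketch of driving the renamed sorting initializer directly on a linear problem, which is handy when inspecting the warm start itself rather than a full solve. This assumes the `SortingInitializer` name introduced below, 1-D sorted clouds, and the uniform-marginal default of `LinearProblem`:

    import jax
    import jax.numpy as jnp

    from ott.core import initializers as init_lib
    from ott.core import linear_problems
    from ott.geometry import pointcloud

    key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
    x = jnp.sort(jax.random.normal(key_x, (64,))).reshape(-1, 1)
    y = jnp.sort(jax.random.normal(key_y, (64,))).reshape(-1, 1)
    geom = pointcloud.PointCloud(x, y, epsilon=0.01)
    prob = linear_problems.LinearProblem(geom=geom)
    f_init = init_lib.SortingInitializer(vector_min=True).init_dual_a(
        prob, lse_mode=True
    )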
diff --git a/ott/core/initializers.py b/ott/core/initializers.py
index c1bea3aae..011f5b35f 100644
--- a/ott/core/initializers.py
+++ b/ott/core/initializers.py
@@ -194,7 +194,7 @@
     return f_potential


-class SortingInit(SinkhornInitializer):
+class SortingInitializer(SinkhornInitializer):
   """Sorting Init class.

   DualSort algorithm from https://arxiv.org/abs/2206.07630: solves the
   non-regularized OT problem via sorting, then computes a potential through
   an iterated minimum on the C-transform, and uses this potential to
   initialize the regularized potential.

diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py
index 8ccda6e7f..6ab7145a8 100644
--- a/tests/core/initializers_test.py
+++ b/tests/core/initializers_test.py
@@ -15,7 +15,6 @@

 import jax
 import jax.numpy as jnp
-import jax.test_util
 import numpy as np
 from absl.testing import absltest, parameterized

@@ -29,7 +28,7 @@
 @jax.jit
 def run_sinkhorn_sort_init(x, y, a=None, b=None, epsilon=0.01, vector_min=True):
   geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
-  sort_init = init_lib.SortingInit(vector_min=vector_min)
+  sort_init = init_lib.SortingInitializer(vector_min=vector_min)
   out = sinkhorn(geom, a=a, b=b, jit=True, potential_initializer=sort_init)
   return out

@@ -63,13 +62,12 @@
     # definte ot problem
     x_init = jnp.array([-1., 0, .22])
     y_init = jnp.array([0., 0, 1.1])
+    x_rng, y_rng = jax.random.split(self.rng)

-    x = jnp.concatenate([
-        x_init, 10 + jnp.abs(jax.random.normal(self.rng, (n,)))
-    ]) * 5
-    y = jnp.concatenate([
-        y_init, 10 + jnp.abs(jax.random.normal(self.rng, (n,)))
-    ]) * 5
+    x = jnp.concatenate([x_init, 10 + jnp.abs(jax.random.normal(x_rng,
+                                                                (n,)))]) * 5
+    y = jnp.concatenate([y_init, 10 + jnp.abs(jax.random.normal(y_rng,
+                                                                (n,)))]) * 5

     x = np.sort(x)
     y = np.sort(y)
@@ -87,13 +85,13 @@
   def create_ot_problem(self, n, m, d, epsilon=0.01):
     # definte ot problem
-    np.random.seed(0)
+    x_rng, y_rng = jax.random.split(self.rng)

     mu_a = np.array([-1, 1]) * 5
     mu_b = np.array([0, 0])

-    x = jax.random.normal(self.rng, (n, d)) + mu_a
-    y = jax.random.normal(self.rng, (m, d)) + mu_b
+    x = jax.random.normal(x_rng, (n, d)) + mu_a
+    y = jax.random.normal(y_rng, (m, d)) + mu_b

     a = np.ones(n) / n
     b = np.ones(m) / m
@@ -109,8 +107,8 @@
   def test_sorting_init(self, vector_min):
     """Tests sorting dual initializer."""

-    n = 500
-    epsilon = 0.01
+    n = 100
+    epsilon = 0.001

     ot_problem = self.create_sorting_problem(n=n, epsilon=epsilon)
     # run sinkhorn
@@ -135,8 +133,8 @@
     )
     sort_num_iter = jnp.sum(sink_out.errors > -1)

-    # check initializer is better
-    self.assertTrue(base_num_iter >= sort_num_iter)
+    # check initializer is better or equal
+    self.assertGreaterEqual(base_num_iter, sort_num_iter)

   def test_default_initializer(self):
     """Tests default initializer"""
@@ -155,8 +153,8 @@
     )

     # check default is 0
-    self.assertTrue((jnp.zeros(n) == default_potential_a).all())
-    self.assertTrue((jnp.zeros(m) == default_potential_b).all())
+    np.testing.assert_array_equal(jnp.zeros(n), default_potential_a)
+    np.testing.assert_array_equal(jnp.zeros(m), default_potential_b)

     # check gausian init returns 0 for non point cloud geometry
     # init initializer
@@ -174,8 +172,8 @@
       ot_problem=ot_problem, lse_mode=True
     )

-    self.assertTrue((jnp.zeros(n) == init_potential_a).all())
-    self.assertTrue((jnp.zeros(m) == init_potential_b).all())
+    np.testing.assert_array_equal(jnp.zeros(n), init_potential_a)
+    np.testing.assert_array_equal(jnp.zeros(m), init_potential_b)

   def test_gaus_initializer(self):
     """Tests Gaussian initializer"""
@@ -207,7 +205,7 @@
     gaus_num_iter = jnp.sum(sink_out.errors > -1)

     # check initializer is better
-    self.assertTrue(base_num_iter >= gaus_num_iter)
+    self.assertGreaterEqual(base_num_iter, gaus_num_iter)


 if __name__ ==
'__main__': diff --git a/tests/core/sinkhorn_test.py b/tests/core/sinkhorn_test.py index 0470547bb..abfd3b10b 100644 --- a/tests/core/sinkhorn_test.py +++ b/tests/core/sinkhorn_test.py @@ -448,8 +448,14 @@ def test_restart(self, lse_mode): default_a = jnp.ones_like(init_dual_a) default_b = jnp.ones_like(init_dual_b) - self.assertTrue((default_a != init_dual_a).all()) - self.assertTrue((default_b != init_dual_b).all()) + self.assertRaises( + AssertionError, + lambda: np.testing.assert_allclose(default_a, init_dual_a) + ) + self.assertRaises( + AssertionError, + lambda: np.testing.assert_allclose(default_b, init_dual_b) + ) out_restarted = sinkhorn.sinkhorn( geom, @@ -466,21 +472,20 @@ def test_restart(self, lse_mode): err_restarted = errors_restarted[errors_restarted > -1][-1] self.assertGreater(threshold, err_restarted) + # check we improve num iter num_iter_restarted = jnp.sum(errors_restarted > -1) - - # check we can only improve on error num_iter = jnp.sum(errors > -1) self.assertGreater(num_iter, num_iter_restarted) + # check only one iteration suffices when restarting with same data. + self.assertEqual(num_iter_restarted, 1) + # check we can only improve on error self.assertGreater(err + threshold, err_restarted) - # # check first error in restart does at least as well as previous best + # check first error in restart does at least as well as previous best self.assertGreater(err + threshold, errors_restarted[0]) - # check only one iteration suffices when restarting with same data. - self.assertEqual(num_iter_restarted, 1) - if __name__ == '__main__': absltest.main() From c12fea8580174f38779e4a3e562bd834d33fd5c9 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Tue, 12 Jul 2022 18:39:02 -0700 Subject: [PATCH 33/46] fix merge conflict --- tests/core/sinkhorn_test.py | 201 ++++++++++++++---------------------- 1 file changed, 78 insertions(+), 123 deletions(-) diff --git a/tests/core/sinkhorn_test.py b/tests/core/sinkhorn_test.py index abfd3b10b..dc37c6f58 100644 --- a/tests/core/sinkhorn_test.py +++ b/tests/core/sinkhorn_test.py @@ -18,17 +18,17 @@ import jax import jax.numpy as jnp import numpy as np -from absl.testing import absltest, parameterized +import pytest -from ott.core import sinkhorn +from ott.core import linear_problems, sinkhorn from ott.geometry import costs, geometry, pointcloud -class SinkhornTest(parameterized.TestCase): +class TestSinkhorn: - def setUp(self): - super().setUp() - self.rng = jax.random.PRNGKey(0) + @pytest.fixture(autouse=True) + def initialize(self, rng: jnp.ndarray): + self.rng = rng self.dim = 4 self.n = 17 self.m = 29 @@ -44,39 +44,12 @@ def setUp(self): self.a = a / jnp.sum(a) self.b = b / jnp.sum(b) - @parameterized.named_parameters( - dict( - testcase_name='lse-Leh-mom', - lse_mode=True, - momentum=1.0, - chg_momentum_from=29, - inner_iterations=10, - norm_error=1 - ), - dict( - testcase_name='scal-Leh-mom', - lse_mode=False, - momentum=1.00, - chg_momentum_from=30, - inner_iterations=10, - norm_error=1 - ), - dict( - testcase_name='lse-Leh-1', - lse_mode=True, - momentum=1.0, - chg_momentum_from=60, - inner_iterations=1, - norm_error=2 - ), - dict( - testcase_name='lse-Leh-24', - lse_mode=True, - momentum=1.0, - chg_momentum_from=12, - inner_iterations=24, - norm_error=4, - ) + @pytest.mark.fast.with_args( + "lse_mode,momentum,chg_momentum_from,inner_iterations,norm_error", + [(True, 1.0, 29, 10, 1), (False, 1.0, 30, 10, 1), (True, 1.0, 60, 1, 2), + (True, 1.0, 12, 24, 4)], + ids=["lse-Leh-mom", "scal-Leh-mom", "lse-Leh-1", "lse-Leh-24"], + 
only_fast=[0, -1], ) def test_euclidean_point_cloud( self, lse_mode, momentum, chg_momentum_from, inner_iterations, norm_error @@ -97,11 +70,11 @@ def test_euclidean_point_cloud( ) errors = out.errors err = errors[errors > -1][-1] - self.assertGreater(threshold, err) + assert threshold > err other_geom = pointcloud.PointCloud(self.x, self.y + 0.3, epsilon=0.1) cost_other = out.cost_at_geom(other_geom) - self.assertIsNot(jnp.isnan(cost_other), True) + assert not jnp.isnan(cost_other) def test_autoepsilon(self): """Check that with auto-epsilon, dual potentials scale.""" @@ -136,14 +109,18 @@ def test_autoepsilon(self): np.testing.assert_allclose(f_1 * scale ** 2, f_2, rtol=1e-3, atol=1e-3) - @parameterized.product( - lse_mode=[True, False], + @pytest.mark.fast.with_args( + lse_mode=[False, True], init=[5], decay=[.9], tau_a=[1.0, .93], - tau_b=[1.0, .91] + tau_b=[1.0, .91], + only_fast=0 ) - def test_autoepsilon_with_decay(self, lse_mode, init, decay, tau_a, tau_b): + def test_autoepsilon_with_decay( + self, lse_mode: bool, init: float, decay: float, tau_a: float, + tau_b: float + ): """Check that variations in init/decay work, and result in same solution.""" geom = pointcloud.PointCloud(self.x, self.y, init=init, decay=decay) out_1 = sinkhorn.sinkhorn( @@ -153,6 +130,7 @@ def test_autoepsilon_with_decay(self, lse_mode, init, decay, tau_a, tau_b): tau_a=tau_a, tau_b=tau_b, jit=True, + lse_mode=lse_mode, threshold=1e-5 ) @@ -164,6 +142,7 @@ def test_autoepsilon_with_decay(self, lse_mode, init, decay, tau_a, tau_b): tau_a=tau_a, tau_b=tau_b, jit=True, + lse_mode=lse_mode, threshold=1e-5 ) # recenter if problem is balanced, since in that case solution is only @@ -176,6 +155,7 @@ def test_autoepsilon_with_decay(self, lse_mode, init, decay, tau_a, tau_b): atol=1e-4 ) + @pytest.mark.fast def test_euclidean_point_cloud_min_iter(self): """Testing the min_iterations parameter.""" threshold = 1e-3 @@ -189,11 +169,11 @@ def test_euclidean_point_cloud_min_iter(self): implicit_differentiation=False ).errors err = errors[jnp.logical_and(errors > -1, jnp.isfinite(errors))][-1] - self.assertGreater(threshold, err) - self.assertEqual(jnp.inf, errors[0]) - self.assertEqual(jnp.inf, errors[1]) - self.assertEqual(jnp.inf, errors[2]) - self.assertGreater(errors[3], 0) + assert threshold > err + assert errors[0] == jnp.inf + assert errors[1] == jnp.inf + assert errors[2] == jnp.inf + assert errors[3] > 0 def test_geom_vs_point_cloud(self): """Two point clouds vs. 
simple cost_matrix execution of sinkorn.""" @@ -206,46 +186,29 @@ def test_geom_vs_point_cloud(self): f_1 -= jnp.mean(f_1[jnp.isfinite(f_1)]) f_2 -= jnp.mean(f_2[jnp.isfinite(f_2)]) - np.testing.assert_allclose(f_1, f_2, rtol=1E-5, atol=1E-5) - - @parameterized.parameters([True], [False]) - def test_euclidean_point_cloud_parallel_weights(self, lse_mode): - """Two point clouds, parallel execution for batched histograms.""" - self.rng, *rngs = jax.random.split(self.rng, 2) - batch = 4 - a = jax.random.uniform(rngs[0], (batch, self.n)) - b = jax.random.uniform(rngs[0], (batch, self.m)) - a = a / jnp.sum(a, axis=1)[:, jnp.newaxis] - b = b / jnp.sum(b, axis=1)[:, jnp.newaxis] - threshold = 1e-3 - geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1, online=True) - errors = sinkhorn.sinkhorn( - geom, a=self.a, b=self.b, threshold=threshold, lse_mode=lse_mode - ).errors - err = errors[errors > -1][-1] - self.assertGreater(jnp.min(threshold - err), 0) + np.testing.assert_allclose(f_1, f_2, rtol=1e-5, atol=1e-5) - @parameterized.parameters([True], [False]) - def test_online_euclidean_point_cloud(self, lse_mode): + @pytest.mark.parametrize("lse_mode", [False, True]) + def test_online_euclidean_point_cloud(self, lse_mode: bool): """Testing the online way to handle geometry.""" threshold = 1e-3 - geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1, online=True) + geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1, batch_size=5) errors = sinkhorn.sinkhorn( geom, a=self.a, b=self.b, threshold=threshold, lse_mode=lse_mode ).errors err = errors[errors > -1][-1] - self.assertGreater(threshold, err) + assert threshold > err - @parameterized.parameters([True], [False]) - def test_online_vs_batch_euclidean_point_cloud(self, lse_mode): + @pytest.mark.fast.with_args("lse_mode", [False, True], only_fast=0) + def test_online_vs_batch_euclidean_point_cloud(self, lse_mode: bool): """Comparing online vs batch geometry.""" threshold = 1e-3 eps = 0.1 online_geom = pointcloud.PointCloud( - self.x, self.y, epsilon=eps, online=True + self.x, self.y, epsilon=eps, batch_size=7 ) online_geom_euc = pointcloud.PointCloud( - self.x, self.y, cost_fn=costs.Euclidean(), epsilon=eps, online=True + self.x, self.y, cost_fn=costs.Euclidean(), epsilon=eps, batch_size=10 ) batch_geom = pointcloud.PointCloud(self.x, self.y, epsilon=eps) @@ -280,8 +243,8 @@ def test_online_vs_batch_euclidean_point_cloud(self, lse_mode): np.testing.assert_allclose( online_geom.transport_from_potentials(out_online.f, out_online.g), batch_geom.transport_from_potentials(out_batch.f, out_batch.g), - rtol=1E-5, - atol=1E-5 + rtol=1e-5, + atol=1e-5 ) np.testing.assert_allclose( @@ -291,8 +254,8 @@ def test_online_vs_batch_euclidean_point_cloud(self, lse_mode): batch_geom_euc.transport_from_potentials( out_batch_euc.f, out_batch_euc.g ), - rtol=1E-5, - atol=1E-5 + rtol=1e-5, + atol=1e-5 ) np.testing.assert_allclose( @@ -300,8 +263,8 @@ def test_online_vs_batch_euclidean_point_cloud(self, lse_mode): batch_geom_euc.transport_from_potentials( out_batch_euc.f, out_batch_euc.g ), - rtol=1E-5, - atol=1E-5 + rtol=1e-5, + atol=1e-5 ) def test_apply_transport_geometry_from_potentials(self): @@ -324,8 +287,8 @@ def test_apply_transport_geometry_from_potentials(self): # test with lse_mode and online = True / False for j, lse_mode in enumerate([True, False]): - for i, online in enumerate([True, False]): - geom = pointcloud.PointCloud(x, y, online=online, epsilon=0.2) + for i, batch_size in enumerate([16, None]): + geom = pointcloud.PointCloud(x, y, 
batch_size=batch_size, epsilon=0.2) sink = sinkhorn.sinkhorn(geom, a, b, lse_mode=lse_mode) transport_t_vec_a[i + 2 * j] = geom.apply_transport_from_potentials( @@ -378,8 +341,8 @@ def test_apply_transport_geometry_from_scalings(self): # test with lse_mode and online = True / False for j, lse_mode in enumerate([True, False]): - for i, online in enumerate([True, False]): - geom = pointcloud.PointCloud(x, y, online=online, epsilon=0.2) + for i, batch_size in enumerate([64, None]): + geom = pointcloud.PointCloud(x, y, batch_size=batch_size, epsilon=0.2) sink = sinkhorn.sinkhorn(geom, a, b, lse_mode=lse_mode) u = geom.scaling_from_potential(sink.f) @@ -406,7 +369,10 @@ def test_apply_transport_geometry_from_scalings(self): rtol=1e-3, atol=1e-3 ) - self.assertIsNot(jnp.any(jnp.isnan(transport_t_vec_a[i + 2 * j])), True) + np.testing.assert_array_equal( + jnp.isnan(transport_t_vec_a[i + 2 * j]), False + ) + for i in range(4): np.testing.assert_allclose( transport_vec_b[i], transport_vec_b[0], rtol=1e-3, atol=1e-3 @@ -415,8 +381,8 @@ def test_apply_transport_geometry_from_scalings(self): transport_t_vec_a[i], transport_t_vec_a[0], rtol=1e-3, atol=1e-3 ) - @parameterized.parameters([True], [False]) - def test_restart(self, lse_mode): + @pytest.mark.parametrize("lse_mode", [False, True]) + def test_restart(self, lse_mode: bool): """Two point clouds, tested with various parameters.""" threshold = 1e-4 geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.01) @@ -430,7 +396,7 @@ def test_restart(self, lse_mode): ) errors = out.errors err = errors[errors > -1][-1] - self.assertGreater(threshold, err) + assert threshold > err # recover solution from previous and ensure faster convergence. if lse_mode: @@ -440,23 +406,6 @@ def test_restart(self, lse_mode): geom.scaling_from_potential(out.f), geom.scaling_from_potential(out.g) ) - - if lse_mode: - default_a = jnp.zeros_like(init_dual_a) - default_b = jnp.zeros_like(init_dual_b) - else: - default_a = jnp.ones_like(init_dual_a) - default_b = jnp.ones_like(init_dual_b) - - self.assertRaises( - AssertionError, - lambda: np.testing.assert_allclose(default_a, init_dual_a) - ) - self.assertRaises( - AssertionError, - lambda: np.testing.assert_allclose(default_b, init_dual_b) - ) - out_restarted = sinkhorn.sinkhorn( geom, a=self.a, @@ -467,25 +416,31 @@ def test_restart(self, lse_mode): init_dual_b=init_dual_b, inner_iterations=1 ) - errors_restarted = out_restarted.errors err_restarted = errors_restarted[errors_restarted > -1][-1] - self.assertGreater(threshold, err_restarted) + assert threshold > err_restarted - # check we improve num iter num_iter_restarted = jnp.sum(errors_restarted > -1) - num_iter = jnp.sum(errors > -1) - self.assertGreater(num_iter, num_iter_restarted) - - # check only one iteration suffices when restarting with same data. - self.assertEqual(num_iter_restarted, 1) - # check we can only improve on error - self.assertGreater(err + threshold, err_restarted) - + assert err > err_restarted # check first error in restart does at least as well as previous best - self.assertGreater(err + threshold, errors_restarted[0]) - - -if __name__ == '__main__': - absltest.main() + assert err > errors_restarted[0] + # check only one iteration suffices when restarting with same data. 
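For reference, the warm-start behavior asserted in this hunk, written out as a standalone sketch (toy shapes and data; only the sinkhorn keyword arguments that already appear in this test are assumed):

import jax
import jax.numpy as jnp
from ott.core import sinkhorn
from ott.geometry import pointcloud

rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(rng_x, (17, 4))
y = jax.random.normal(rng_y, (29, 4))
a = jnp.ones(17) / 17
b = jnp.ones(29) / 29
geom = pointcloud.PointCloud(x, y, epsilon=0.01)

out = sinkhorn.sinkhorn(geom, a=a, b=b, threshold=1e-4, lse_mode=True)
# Feeding the converged potentials back as initializations should let a
# re-run on the same data terminate after a single inner iteration.
restarted = sinkhorn.sinkhorn(
    geom, a=a, b=b, threshold=1e-4, lse_mode=True,
    init_dual_a=out.f, init_dual_b=out.g, inner_iterations=1
)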
+ assert num_iter_restarted == 1 + + @pytest.mark.limit_memory("90 MB") + @pytest.mark.fast.with_args("batch_size", [500, 1000], only_fast=0) + def test_sinkhorn_online_memory(self, batch_size: int): + # offline: Total memory allocated: 240.1MiB + # online (500): Total memory allocated: 33.4MiB + # online (1000): Total memory allocated: 45.6MiB + rngs = jax.random.split(jax.random.PRNGKey(0), 4) + n, m = 5000, 4000 + x = jax.random.uniform(rngs[0], (n, 2)) + y = jax.random.uniform(rngs[1], (m, 2)) + geom = pointcloud.PointCloud(x, y, batch_size=batch_size, epsilon=1) + problem = linear_problems.LinearProblem(geom) + solver = sinkhorn.Sinkhorn() + + out = solver(problem) + assert out.converged \ No newline at end of file From 1c56799cd1411a8ef189290d4270a500cdcdee3f Mon Sep 17 00:00:00 2001 From: James Thornton Date: Tue, 12 Jul 2022 18:59:16 -0700 Subject: [PATCH 34/46] resolve test errors in sinkhorn test --- tests/core/sinkhorn_test.py | 183 ++++++++++++++++++++++-------------- 1 file changed, 111 insertions(+), 72 deletions(-) diff --git a/tests/core/sinkhorn_test.py b/tests/core/sinkhorn_test.py index dc37c6f58..5aae11d81 100644 --- a/tests/core/sinkhorn_test.py +++ b/tests/core/sinkhorn_test.py @@ -18,17 +18,17 @@ import jax import jax.numpy as jnp import numpy as np -import pytest +from absl.testing import absltest, parameterized -from ott.core import linear_problems, sinkhorn +from ott.core import sinkhorn from ott.geometry import costs, geometry, pointcloud -class TestSinkhorn: +class SinkhornTest(parameterized.TestCase): - @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): - self.rng = rng + def setUp(self): + super().setUp() + self.rng = jax.random.PRNGKey(0) self.dim = 4 self.n = 17 self.m = 29 @@ -44,12 +44,39 @@ def initialize(self, rng: jnp.ndarray): self.a = a / jnp.sum(a) self.b = b / jnp.sum(b) - @pytest.mark.fast.with_args( - "lse_mode,momentum,chg_momentum_from,inner_iterations,norm_error", - [(True, 1.0, 29, 10, 1), (False, 1.0, 30, 10, 1), (True, 1.0, 60, 1, 2), - (True, 1.0, 12, 24, 4)], - ids=["lse-Leh-mom", "scal-Leh-mom", "lse-Leh-1", "lse-Leh-24"], - only_fast=[0, -1], + @parameterized.named_parameters( + dict( + testcase_name='lse-Leh-mom', + lse_mode=True, + momentum=1.0, + chg_momentum_from=29, + inner_iterations=10, + norm_error=1 + ), + dict( + testcase_name='scal-Leh-mom', + lse_mode=False, + momentum=1.00, + chg_momentum_from=30, + inner_iterations=10, + norm_error=1 + ), + dict( + testcase_name='lse-Leh-1', + lse_mode=True, + momentum=1.0, + chg_momentum_from=60, + inner_iterations=1, + norm_error=2 + ), + dict( + testcase_name='lse-Leh-24', + lse_mode=True, + momentum=1.0, + chg_momentum_from=12, + inner_iterations=24, + norm_error=4, + ) ) def test_euclidean_point_cloud( self, lse_mode, momentum, chg_momentum_from, inner_iterations, norm_error @@ -70,11 +97,11 @@ def test_euclidean_point_cloud( ) errors = out.errors err = errors[errors > -1][-1] - assert threshold > err + self.assertGreater(threshold, err) other_geom = pointcloud.PointCloud(self.x, self.y + 0.3, epsilon=0.1) cost_other = out.cost_at_geom(other_geom) - assert not jnp.isnan(cost_other) + self.assertIsNot(jnp.isnan(cost_other), True) def test_autoepsilon(self): """Check that with auto-epsilon, dual potentials scale.""" @@ -109,18 +136,14 @@ def test_autoepsilon(self): np.testing.assert_allclose(f_1 * scale ** 2, f_2, rtol=1e-3, atol=1e-3) - @pytest.mark.fast.with_args( - lse_mode=[False, True], + @parameterized.product( + lse_mode=[True, False], init=[5], decay=[.9], 
tau_a=[1.0, .93], - tau_b=[1.0, .91], - only_fast=0 + tau_b=[1.0, .91] ) - def test_autoepsilon_with_decay( - self, lse_mode: bool, init: float, decay: float, tau_a: float, - tau_b: float - ): + def test_autoepsilon_with_decay(self, lse_mode, init, decay, tau_a, tau_b): """Check that variations in init/decay work, and result in same solution.""" geom = pointcloud.PointCloud(self.x, self.y, init=init, decay=decay) out_1 = sinkhorn.sinkhorn( @@ -130,7 +153,6 @@ def test_autoepsilon_with_decay( tau_a=tau_a, tau_b=tau_b, jit=True, - lse_mode=lse_mode, threshold=1e-5 ) @@ -142,7 +164,6 @@ def test_autoepsilon_with_decay( tau_a=tau_a, tau_b=tau_b, jit=True, - lse_mode=lse_mode, threshold=1e-5 ) # recenter if problem is balanced, since in that case solution is only @@ -155,7 +176,6 @@ def test_autoepsilon_with_decay( atol=1e-4 ) - @pytest.mark.fast def test_euclidean_point_cloud_min_iter(self): """Testing the min_iterations parameter.""" threshold = 1e-3 @@ -169,11 +189,11 @@ def test_euclidean_point_cloud_min_iter(self): implicit_differentiation=False ).errors err = errors[jnp.logical_and(errors > -1, jnp.isfinite(errors))][-1] - assert threshold > err - assert errors[0] == jnp.inf - assert errors[1] == jnp.inf - assert errors[2] == jnp.inf - assert errors[3] > 0 + self.assertGreater(threshold, err) + self.assertEqual(jnp.inf, errors[0]) + self.assertEqual(jnp.inf, errors[1]) + self.assertEqual(jnp.inf, errors[2]) + self.assertGreater(errors[3], 0) def test_geom_vs_point_cloud(self): """Two point clouds vs. simple cost_matrix execution of sinkorn.""" @@ -186,29 +206,46 @@ def test_geom_vs_point_cloud(self): f_1 -= jnp.mean(f_1[jnp.isfinite(f_1)]) f_2 -= jnp.mean(f_2[jnp.isfinite(f_2)]) - np.testing.assert_allclose(f_1, f_2, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(f_1, f_2, rtol=1E-5, atol=1E-5) + + @parameterized.parameters([True], [False]) + def test_euclidean_point_cloud_parallel_weights(self, lse_mode): + """Two point clouds, parallel execution for batched histograms.""" + self.rng, *rngs = jax.random.split(self.rng, 2) + batch = 4 + a = jax.random.uniform(rngs[0], (batch, self.n)) + b = jax.random.uniform(rngs[0], (batch, self.m)) + a = a / jnp.sum(a, axis=1)[:, jnp.newaxis] + b = b / jnp.sum(b, axis=1)[:, jnp.newaxis] + threshold = 1e-3 + geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1, online=True) + errors = sinkhorn.sinkhorn( + geom, a=self.a, b=self.b, threshold=threshold, lse_mode=lse_mode + ).errors + err = errors[errors > -1][-1] + self.assertGreater(jnp.min(threshold - err), 0) - @pytest.mark.parametrize("lse_mode", [False, True]) - def test_online_euclidean_point_cloud(self, lse_mode: bool): + @parameterized.parameters([True], [False]) + def test_online_euclidean_point_cloud(self, lse_mode): """Testing the online way to handle geometry.""" threshold = 1e-3 - geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1, batch_size=5) + geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1, online=True) errors = sinkhorn.sinkhorn( geom, a=self.a, b=self.b, threshold=threshold, lse_mode=lse_mode ).errors err = errors[errors > -1][-1] - assert threshold > err + self.assertGreater(threshold, err) - @pytest.mark.fast.with_args("lse_mode", [False, True], only_fast=0) - def test_online_vs_batch_euclidean_point_cloud(self, lse_mode: bool): + @parameterized.parameters([True], [False]) + def test_online_vs_batch_euclidean_point_cloud(self, lse_mode): """Comparing online vs batch geometry.""" threshold = 1e-3 eps = 0.1 online_geom = pointcloud.PointCloud( - self.x, self.y, 
epsilon=eps, batch_size=7 + self.x, self.y, epsilon=eps, online=True ) online_geom_euc = pointcloud.PointCloud( - self.x, self.y, cost_fn=costs.Euclidean(), epsilon=eps, batch_size=10 + self.x, self.y, cost_fn=costs.Euclidean(), epsilon=eps, online=True ) batch_geom = pointcloud.PointCloud(self.x, self.y, epsilon=eps) @@ -243,8 +280,8 @@ def test_online_vs_batch_euclidean_point_cloud(self, lse_mode: bool): np.testing.assert_allclose( online_geom.transport_from_potentials(out_online.f, out_online.g), batch_geom.transport_from_potentials(out_batch.f, out_batch.g), - rtol=1e-5, - atol=1e-5 + rtol=1E-5, + atol=1E-5 ) np.testing.assert_allclose( @@ -254,8 +291,8 @@ def test_online_vs_batch_euclidean_point_cloud(self, lse_mode: bool): batch_geom_euc.transport_from_potentials( out_batch_euc.f, out_batch_euc.g ), - rtol=1e-5, - atol=1e-5 + rtol=1E-5, + atol=1E-5 ) np.testing.assert_allclose( @@ -263,8 +300,8 @@ def test_online_vs_batch_euclidean_point_cloud(self, lse_mode: bool): batch_geom_euc.transport_from_potentials( out_batch_euc.f, out_batch_euc.g ), - rtol=1e-5, - atol=1e-5 + rtol=1E-5, + atol=1E-5 ) def test_apply_transport_geometry_from_potentials(self): @@ -287,8 +324,8 @@ def test_apply_transport_geometry_from_potentials(self): # test with lse_mode and online = True / False for j, lse_mode in enumerate([True, False]): - for i, batch_size in enumerate([16, None]): - geom = pointcloud.PointCloud(x, y, batch_size=batch_size, epsilon=0.2) + for i, online in enumerate([True, False]): + geom = pointcloud.PointCloud(x, y, online=online, epsilon=0.2) sink = sinkhorn.sinkhorn(geom, a, b, lse_mode=lse_mode) transport_t_vec_a[i + 2 * j] = geom.apply_transport_from_potentials( @@ -341,8 +378,8 @@ def test_apply_transport_geometry_from_scalings(self): # test with lse_mode and online = True / False for j, lse_mode in enumerate([True, False]): - for i, batch_size in enumerate([64, None]): - geom = pointcloud.PointCloud(x, y, batch_size=batch_size, epsilon=0.2) + for i, online in enumerate([True, False]): + geom = pointcloud.PointCloud(x, y, online=online, epsilon=0.2) sink = sinkhorn.sinkhorn(geom, a, b, lse_mode=lse_mode) u = geom.scaling_from_potential(sink.f) @@ -369,10 +406,7 @@ def test_apply_transport_geometry_from_scalings(self): rtol=1e-3, atol=1e-3 ) - np.testing.assert_array_equal( - jnp.isnan(transport_t_vec_a[i + 2 * j]), False - ) - + self.assertIsNot(jnp.any(jnp.isnan(transport_t_vec_a[i + 2 * j])), True) for i in range(4): np.testing.assert_allclose( transport_vec_b[i], transport_vec_b[0], rtol=1e-3, atol=1e-3 @@ -381,8 +415,8 @@ def test_apply_transport_geometry_from_scalings(self): transport_t_vec_a[i], transport_t_vec_a[0], rtol=1e-3, atol=1e-3 ) - @pytest.mark.parametrize("lse_mode", [False, True]) - def test_restart(self, lse_mode: bool): + @parameterized.parameters([True], [False]) + def test_restart(self, lse_mode): """Two point clouds, tested with various parameters.""" threshold = 1e-4 geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.01) @@ -396,7 +430,7 @@ def test_restart(self, lse_mode: bool): ) errors = out.errors err = errors[errors > -1][-1] - assert threshold > err + self.assertGreater(threshold, err) # recover solution from previous and ensure faster convergence. 
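The scaling_from_potential conversion used just below follows the standard entropic-OT change of variables between lse-mode potentials and kernel-mode scalings; a small sketch of the relation (epsilon stands for the geometry's regularization strength):

import jax.numpy as jnp

def scaling_from_potential(f: jnp.ndarray, epsilon: float) -> jnp.ndarray:
  # Kernel-mode scaling from an lse-mode potential: u = exp(f / epsilon).
  return jnp.exp(f / epsilon)

def potential_from_scaling(u: jnp.ndarray, epsilon: float) -> jnp.ndarray:
  # Inverse map: f = epsilon * log(u).
  return epsilon * jnp.log(u)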
if lse_mode: @@ -406,6 +440,23 @@ def test_restart(self, lse_mode: bool): geom.scaling_from_potential(out.f), geom.scaling_from_potential(out.g) ) + + if lse_mode: + default_a = jnp.zeros_like(init_dual_a) + default_b = jnp.zeros_like(init_dual_b) + else: + default_a = jnp.ones_like(init_dual_a) + default_b = jnp.ones_like(init_dual_b) + + self.assertRaises( + AssertionError, + lambda: np.testing.assert_allclose(default_a, init_dual_a) + ) + self.assertRaises( + AssertionError, + lambda: np.testing.assert_allclose(default_b, init_dual_b) + ) + out_restarted = sinkhorn.sinkhorn( geom, a=self.a, @@ -416,6 +467,7 @@ def test_restart(self, lse_mode: bool): init_dual_b=init_dual_b, inner_iterations=1 ) + errors_restarted = out_restarted.errors err_restarted = errors_restarted[errors_restarted > -1][-1] assert threshold > err_restarted @@ -428,19 +480,6 @@ def test_restart(self, lse_mode: bool): # check only one iteration suffices when restarting with same data. assert num_iter_restarted == 1 - @pytest.mark.limit_memory("90 MB") - @pytest.mark.fast.with_args("batch_size", [500, 1000], only_fast=0) - def test_sinkhorn_online_memory(self, batch_size: int): - # offline: Total memory allocated: 240.1MiB - # online (500): Total memory allocated: 33.4MiB - # online (1000): Total memory allocated: 45.6MiB - rngs = jax.random.split(jax.random.PRNGKey(0), 4) - n, m = 5000, 4000 - x = jax.random.uniform(rngs[0], (n, 2)) - y = jax.random.uniform(rngs[1], (m, 2)) - geom = pointcloud.PointCloud(x, y, batch_size=batch_size, epsilon=1) - problem = linear_problems.LinearProblem(geom) - solver = sinkhorn.Sinkhorn() - - out = solver(problem) - assert out.converged \ No newline at end of file + +if __name__ == '__main__': + absltest.main() From 161a67a75e855e05c315693d76cd6887b100ce7f Mon Sep 17 00:00:00 2001 From: James Thornton Date: Thu, 14 Jul 2022 15:09:27 -0700 Subject: [PATCH 35/46] incorporate feedback, update tests to pytest, change docstrings, introduce defaultinit class --- docs/core.rst | 2 +- ott/core/initializers.py | 243 +++++++++++--------------------- ott/core/sinkhorn.py | 16 +-- tests/core/initializers_test.py | 138 +++++++++--------- tests/core/sinkhorn_test.py | 13 +- 5 files changed, 166 insertions(+), 246 deletions(-) diff --git a/docs/core.rst b/docs/core.rst index 446f37dc5..3f84464f5 100644 --- a/docs/core.rst +++ b/docs/core.rst @@ -32,7 +32,7 @@ Sinkhorn sinkhorn.SinkhornOutput Sinkhorn Dual Initializers --------- +-------------------------- .. autosummary:: :toctree: _autosummary diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 011f5b35f..139aa8db3 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Sinkhorn initializers.""" -from typing import Optional, Tuple +from typing import Optional import jax import jax.numpy as jnp @@ -21,139 +21,71 @@ from ott.geometry import pointcloud -def _default_dual_a( - ot_problem: linear_problems.LinearProblem, lse_mode: bool -) -> jnp.ndarray: - """Return dual potential vector, f. - - Args: - ot_problem: - lse_mode: Return potentials if true, scaling if false. - - Returns: - potential f, array of size n - """ - a = ot_problem.a - init_dual_a = jnp.zeros_like(a) if lse_mode else jnp.ones_like(a) - return init_dual_a - - -def _default_dual_b( - ot_problem: linear_problems.LinearProblem, lse_mode: bool -) -> jnp.ndarray: - """Return dual potential vector, g. 
- - Args: - ot_problem: - lse_mode: Return potentials if true, scaling if false. - - Returns: - potential g, array of size m - """ - b = ot_problem.b - init_dual_b = jnp.zeros_like(b) if lse_mode else jnp.ones_like(b) - return init_dual_b - - -def _remove_single_weight_potential( - weights: jnp.ndarray, init_dual: jnp.ndarray, lse_mode: bool -) -> Tuple[jnp.ndarray]: - """Cancel dual variables for zero weights. - - Args: - weights: array of probability masses - init_dual: dual potential array - lse_mode: Return potentials if true, scaling if false. - Returns: - potential - """ - return jnp.where(weights > 0, init_dual, -jnp.inf if lse_mode else 0.0) - - -def remove_weight_potentials( - weights_a: jnp.ndarray, weights_b: jnp.ndarray, init_dual_a: jnp.ndarray, - init_dual_b: jnp.ndarray, lse_mode: bool -) -> Tuple[jnp.ndarray]: - """Cancel dual variables for zero weights. - - Args: - weights_a: array of probability masses, array of size n - weights_b: array of probability masses, array of size m - init_dual_a: potential f, array of size n - init_dual_b: potential g, array of size m - lse_mode: Return potentials if true, scaling if false. +class SinkhornInitializer: - Returns: - potentials (f,g) - """ - init_dual_a = _remove_single_weight_potential( - weights_a, init_dual_a, lse_mode - ) - init_dual_b = _remove_single_weight_potential( - weights_b, init_dual_b, lse_mode - ) - return init_dual_a, init_dual_b + def init_dual_a( + self, ot_problem: linear_problems.LinearProblem, lse_mode: bool + ) -> jnp.ndarray: + """Initialization for Sinkhorn potential/ scaling f_u.""" + def init_dual_b( + self, ot_problem: linear_problems.LinearProblem, lse_mode: bool + ) -> jnp.ndarray: + """Initialization for Sinkhorn potential/ scaling g_v.""" -class SinkhornInitializer: - """Initialization of Sinkhorn dual potentials. - Args: - ot_problem: OT problem between discrete distributions of size n and m. - lse_mode: Return potential if true, scaling if false. +class DefaultInitializer(SinkhornInitializer): + """Default Initialization of Sinkhorn dual potentials/ primal scalings. - Returns: - dual potential, array of size n """ def init_dual_a( self, ot_problem: linear_problems.LinearProblem, lse_mode: bool ) -> jnp.ndarray: - """Initialzation for Sinkhorn potential f. + """Initialization for Sinkhorn potential/ scaling f_u. Args: ot_problem: OT problem between discrete distributions of size n and m. lse_mode: Return potential if true, scaling if false. Returns: - dual potential, array of size n + potential/ scaling, array of size n """ - return _default_dual_a(ot_problem=ot_problem, lse_mode=lse_mode) + a = ot_problem.a + init_dual_a = jnp.zeros_like(a) if lse_mode else jnp.ones_like(a) + return init_dual_a def init_dual_b( self, ot_problem: linear_problems.LinearProblem, lse_mode: bool ) -> jnp.ndarray: - """Initialzation for Sinkhorn potential g. + """Initialization for Sinkhorn potential/ scaling g_v. Args: ot_problem: OT problem between discrete distributions of size n and m. lse_mode: Return potential if true, scaling if false. Returns: - dual potential, array of size m + potential/ scaling, array of size m """ - return _default_dual_b(ot_problem=ot_problem, lse_mode=lse_mode) + b = ot_problem.b + init_dual_b = jnp.zeros_like(b) if lse_mode else jnp.ones_like(b) + return init_dual_b -class GaussianInitializer(SinkhornInitializer): +class GaussianInitializer(DefaultInitializer): """GaussianInitializer. From https://arxiv.org/abs/2206.07630. 
Compute Gaussian approximations of each pointcloud, then compute closed form - Kantorovic potential betwen Gaussian approximations using Brenier's theorem - (adapt convex/ Brenier potential to Kantoroic). Use this Gaussian potential to - initialize Sinkhorn potentials. + Kantorovich potential between Gaussian approximations using Brenier's theorem + (adapt convex/ Brenier potential to Kantorovich). Use this Gaussian potential to + initialize Sinkhorn potentials/ scalings. - Args: - stop_gradient: Defaults to True. """ - def __init__(self, stop_gradient: bool = True) -> None: - + def __init__(self): super().__init__() - self.stop_gradient = stop_gradient - def init_dual_a( self, ot_problem: linear_problems.LinearProblem, @@ -163,38 +95,34 @@ def init_dual_a( Args: ot_problem: OT problem description with geometry and weights. - init_f: Pre dual sort initialization, when none sets entries as 0. lse_mode: Return potential if true, scaling if false. Returns: - potential f, array of size n. + potential/ scaling f_u, array of size n. """ # import Gaussian here due to circular imports from ott.tools.gaussian_mixture import gaussian if not isinstance(ot_problem.geom, pointcloud.PointCloud): # warning that init not applied - return _default_dual_a(ot_problem, lse_mode) + return super().init_dual_a(ot_problem, lse_mode) else: x, y = ot_problem.geom.x, ot_problem.geom.y a, b = ot_problem.a, ot_problem.b - if self.stop_gradient: - x, y = jax.lax.stop_gradient(x), jax.lax.stop_gradient(y) - a, b = jax.lax.stop_gradient(a), jax.lax.stop_gradient(b) gaussian_a = gaussian.Gaussian.from_samples(x, weights=a) gaussian_b = gaussian.Gaussian.from_samples(y, weights=b) # Brenier potential for cost ||x-y||^2/2, multiply by two for ||x-y||^2 f_potential = 2 * gaussian_a.f_potential(dest=gaussian_b, points=x) f_potential = f_potential - jnp.mean(f_potential) - f_potential = f_potential if lse_mode else ot_problem.scaling_from_potential( + f_u = f_potential if lse_mode else ot_problem.scaling_from_potential( f_potential ) - return f_potential + return f_u -class SortingInitializer(SinkhornInitializer): +class SortingInitializer(DefaultInitializer): """Sorting Init class. DualSort algorithm from https://arxiv.org/abs/2206.07630, solve non-regularized OT problem via sorting, then compute potential through iterated minimum on C-transform and use this potentials to initialize regularized potential Args: - vector_min: Use vectorized inner loop if true. Defaults to True. - tol: DualSort convergence threshold. Defaults to 1e-2. - max_iter: Max DualSort steps. Defaults to 100. - stop_gradient: Do not trace gradient. Defaults to True. + vectorized_update: Use vectorized inner loop if true. Defaults to True. + tol: DualSort convergence threshold. Defaults to 1e-2. + max_iter: Max DualSort steps. Defaults to 100. """ def __init__( self, - vector_min: bool = True, - tol: float = 1e-2, - max_iter: int = 100, - stop_gradient: bool = True - ) -> None: + vectorized_update: bool = True, + tolerance: float = 1e-2, + max_iter: int = 100 + ): super().__init__() - self.tolerance = tol - self.stop_gradient = stop_gradient + self.tolerance = tolerance self.max_iter = max_iter self.update_fn = lambda f, mod_cost: jax.lax.cond( - vector_min, self.vectorized_update, self.coordinate_update, f, mod_cost + vectorized_update, _vectorized_update, _coordinate_update, f, mod_cost ) - def vectorized_update( - self, f: jnp.ndarray, modified_cost: jnp.ndarray - ) -> jnp.ndarray: - """Inner loop DualSort Update. - - Args: - f : potential f, array of size n. - modified_cost: cost matrix minus diagonal column-wise. - - Returns: - updated potential vector, f.
- """ - f = jnp.min(modified_cost + f[None, :], axis=1) - return f - - def coordinate_update( - self, f: jnp.ndarray, modified_cost: jnp.ndarray - ) -> jnp.ndarray: - """Coordinate-wise updates within inner loop. - - Args: - f: potential f, array of size n. - modified_cost: cost matrix minus diagonal column-wise. - - Returns: - updated potential vector, f. - """ - - def body_fn(i, f): - new_f = jnp.min(modified_cost[i, :] + f) - f = f.at[i].set(new_f) - return f - - return jax.lax.fori_loop(0, len(f), body_fn, f) - def init_sorting_dual( self, modified_cost: jnp.ndarray, f_potential: jnp.ndarray ) -> jnp.ndarray: @@ -276,10 +166,10 @@ def init_sorting_dual( def body_fn(state): prev_f, _, it = state - f_potential = self.update_fn(prev_f, modified_cost) - diff = jnp.sum((f_potential - prev_f) ** 2) + new_f = self.update_fn(prev_f, modified_cost) + diff = jnp.sum((new_f - prev_f) ** 2) it += 1 - return f_potential, diff, it + return new_f, diff, it def cond_fn(state): _, diff, it = state @@ -306,15 +196,13 @@ def init_dual_a( Args: ot_problem: OT problem. lse_mode: Return potential if true, scaling if false. - init_f: potential f, array of size n. + init_f: potential f, array of size n. This is the starting potential, + which is then updated to make the init potential, so an init of an init. Returns: - potential f, array of size n. + potential/ scaling f_u, array of size n. """ cost_matrix = ot_problem.geom.cost_matrix - if self.stop_gradient: - cost_matrix = jax.lax.stop_gradient(cost_matrix) - modified_cost = cost_matrix - jnp.diag(cost_matrix)[None, :] n = cost_matrix.shape[0] @@ -323,8 +211,45 @@ def init_dual_a( f_potential = self.init_sorting_dual(modified_cost, f_potential) f_potential = f_potential - jnp.mean(f_potential) - f_potential = f_potential if lse_mode else ot_problem.scaling_from_potential( + f_u = f_potential if lse_mode else ot_problem.scaling_from_potential( f_potential ) - return f_potential + return f_u + + +def _vectorized_update( + f: jnp.ndarray, modified_cost: jnp.ndarray +) -> jnp.ndarray: + """Inner loop DualSort Update. + + Args: + f : potential f, array of size n. + modified_cost: cost matrix minus diagonal column-wise. + + Returns: + updated potential vector, f. + """ + f = jnp.min(modified_cost + f[None, :], axis=1) + return f + + +def _coordinate_update( + f: jnp.ndarray, modified_cost: jnp.ndarray +) -> jnp.ndarray: + """Coordinate-wise updates within inner loop. + + Args: + f: potential f, array of size n. + modified_cost: cost matrix minus diagonal column-wise. + + Returns: + updated potential vector, f. + """ + + def body_fn(i, f): + new_f = jnp.min(modified_cost[i, :] + f) + f = f.at[i].set(new_f) + return f + + return jax.lax.fori_loop(0, len(f), body_fn, f) diff --git a/ott/core/sinkhorn.py b/ott/core/sinkhorn.py index 9b29c5be5..bb2e05a3a 100644 --- a/ott/core/sinkhorn.py +++ b/ott/core/sinkhorn.py @@ -351,7 +351,7 @@ def __init__( implicit_diff: Optional[implicit_lib.ImplicitDiff ] = implicit_lib.ImplicitDiff(), # noqa: E124 potential_initializer: init_lib.SinkhornInitializer = init_lib - .SinkhornInitializer(), + .DefaultInitializer(), jit: bool = True ): self.lse_mode = lse_mode @@ -418,13 +418,13 @@ def __call__( ) # Cancel dual variables for zero weights. 
- init_dual_a, init_dual_b = init_lib.remove_weight_potentials( - weights_a=ot_prob.a, - weights_b=ot_prob.b, - init_dual_a=init_dual_a, - init_dual_b=init_dual_b, - lse_mode=self.lse_mode + init_dual_a = jnp.where( + ot_prob.a > 0, init_dual_a, -jnp.inf if self.lse_mode else 0.0 ) + init_dual_b = jnp.where( + ot_prob.b > 0, init_dual_b, -jnp.inf if self.lse_mode else 0.0 + ) + run_fn = jax.jit(run) if self.jit else run return run_fn(ot_prob, self, (init_dual_a, init_dual_b)) @@ -703,7 +703,7 @@ def make( parallel_dual_updates: bool = False, use_danskin: bool = None, potential_initializer: init_lib.SinkhornInitializer = init_lib - .SinkhornInitializer(), + .DefaultInitializer(), jit: bool = False ) -> Sinkhorn: """For backward compatibility.""" diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py index 6ab7145a8..2053dc0a9 100644 --- a/tests/core/initializers_test.py +++ b/tests/core/initializers_test.py @@ -16,7 +16,7 @@ import jax import jax.numpy as jnp import numpy as np -from absl.testing import absltest, parameterized +import pytest from ott.core import initializers as init_lib from ott.core import linear_problems from ott.core.sinkhorn import sinkhorn from ott.geometry import geometry, pointcloud +def create_sorting_problem(rng, n, epsilon=0.01): + # define the OT problem + x_init = jnp.array([-1., 0, .22]) + y_init = jnp.array([0., 0, 1.1]) + x_rng, y_rng = jax.random.split(rng) + + x = jnp.concatenate([x_init, 10 + jnp.abs(jax.random.normal(x_rng, (n,)))]) + y = jnp.concatenate([y_init, 10 + jnp.abs(jax.random.normal(y_rng, (n,)))]) + + x = np.sort(x) + y = np.sort(y) + + n = len(x) + m = len(y) + a = np.ones(n) / n + b = np.ones(m) / m + + geom = pointcloud.PointCloud( + x.reshape(-1, 1), y.reshape(-1, 1), epsilon=epsilon + ) + ot_problem = linear_problems.LinearProblem(geom=geom, a=a, b=b) + + return ot_problem + + +def create_ot_problem(rng, n, m, d, epsilon=0.01): + # define the OT problem + x_rng, y_rng = jax.random.split(rng) + + mu_a = np.array([-1, 1]) * 5 + mu_b = np.array([0, 0]) + + x = jax.random.normal(x_rng, (n, d)) + mu_a + y = jax.random.normal(y_rng, (m, d)) + mu_b + + a = np.ones(n) / n + b = np.ones(m) / m + + x_jnp, y_jnp = jnp.array(x), jnp.array(y) + + geom = pointcloud.PointCloud(x_jnp, y_jnp, epsilon=epsilon) + + ot_problem = linear_problems.LinearProblem(geom=geom, a=a, b=b) + return ot_problem + + # define sinkhorn functions @jax.jit def run_sinkhorn_sort_init(x, y, a=None, b=None, epsilon=0.01, vector_min=True): geom = pointcloud.PointCloud(x, y, epsilon=epsilon) - sort_init = init_lib.SortingInitializer(vector_min=vector_min) + sort_init = init_lib.SortingInitializer(vectorized_update=vector_min) out = sinkhorn(geom, a=a, b=b, jit=True, potential_initializer=sort_init) return out @@ -53,77 +99,31 @@ def run_sinkhorn_gaus_init(x, y, a=None, b=None, epsilon=0.01): return out -class InitializerTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.rng = jax.random.PRNGKey(0) - - def create_sorting_problem(self, n, epsilon=0.01): - # definte ot problem - x_init = jnp.array([-1., 0, .22]) - y_init = jnp.array([0., 0, 1.1]) - x_rng, y_rng = jax.random.split(self.rng) - - x = jnp.concatenate([x_init, 10 + jnp.abs(jax.random.normal(x_rng, - (n,)))]) * 5 - y = jnp.concatenate([y_init, 10 + jnp.abs(jax.random.normal(y_rng, - (n,)))]) * 5 - - x = np.sort(x) - y = np.sort(y) - - n = len(x) - m = len(y) - a = np.ones(n) / n - b =
np.ones(m) / m - - geom = pointcloud.PointCloud( - x.reshape(-1, 1), y.reshape(-1, 1), epsilon=epsilon - ) - ot_problem = linear_problems.LinearProblem(geom=geom, a=a, b=b) - - return ot_problem - - def create_ot_problem(self, n, m, d, epsilon=0.01): - # definte ot problem - x_rng, y_rng = jax.random.split(self.rng) - - mu_a = np.array([-1, 1]) * 5 - mu_b = np.array([0, 0]) - - x = jax.random.normal(x_rng, (n, d)) + mu_a - y = jax.random.normal(y_rng, (m, d)) + mu_b - - a = np.ones(n) / n - b = np.ones(m) / m - - x_jnp, y_jnp = jnp.array(x), jnp.array(y) - - geom = pointcloud.PointCloud(x_jnp, y_jnp, epsilon=epsilon) - - ot_problem = linear_problems.LinearProblem(geom=geom, a=a, b=b) - return ot_problem - - @parameterized.parameters([True], [False]) + @pytest.mark.fast.with_args("vector_min", [False, True]) def test_sorting_init(self, vector_min): """Tests sorting dual initializer.""" - n = 100 - epsilon = 0.001 + n = 500 + epsilon = 0.01 - ot_problem = self.create_sorting_problem(n=n, epsilon=epsilon) + ot_problem = create_sorting_problem(rng=self.rng, n=n, epsilon=epsilon) # run sinkhorn - sink_out = run_sinkhorn( + sink_out_base = run_sinkhorn( x=ot_problem.geom.x, y=ot_problem.geom.y, a=ot_problem.a, b=ot_problem.b, epsilon=epsilon ) - base_num_iter = jnp.sum(sink_out.errors > -1) + base_num_iter = jnp.sum(sink_out_base.errors > -1) - sink_out = run_sinkhorn_sort_init( + sink_out_init = run_sinkhorn_sort_init( x=ot_problem.geom.x, y=ot_problem.geom.y, a=ot_problem.a, @@ -131,11 +131,12 @@ def test_sorting_init(self, vector_min): epsilon=epsilon, vector_min=vector_min ) - sort_num_iter = jnp.sum(sink_out.errors > -1) + sort_num_iter = jnp.sum(sink_out_init.errors > -1) # check initializer is better or equal - self.assertGreaterEqual(base_num_iter, sort_num_iter) + assert base_num_iter > sort_num_iter + @pytest.mark.fast def test_default_initializer(self): """Tests default initializer""" n = 200 @@ -143,12 +144,12 @@ def test_default_initializer(self): d = 2 epsilon = 0.01 - ot_problem = self.create_ot_problem(n, m, d) + ot_problem = create_ot_problem(self.rng, n, m, d) - default_potential_a = init_lib._default_dual_a( + default_potential_a = init_lib.DefaultInitializer().init_dual_a( ot_problem=ot_problem, lse_mode=True ) - default_potential_b = init_lib._default_dual_b( + default_potential_b = init_lib.DefaultInitializer().init_dual_b( ot_problem=ot_problem, lse_mode=True ) @@ -175,6 +176,7 @@ def test_default_initializer(self): np.testing.assert_array_equal(jnp.zeros(n), init_potential_a) np.testing.assert_array_equal(jnp.zeros(m), init_potential_b) + @pytest.mark.fast def test_gaus_initializer(self): """Tests Gaussian initializer""" # definte ot problem @@ -183,7 +185,7 @@ def test_gaus_initializer(self): d = 2 epsilon = 0.01 - ot_problem = self.create_ot_problem(n, m, d) + ot_problem = create_ot_problem(self.rng, n, m, d) # run sinkhorn sink_out = run_sinkhorn( @@ -205,8 +207,4 @@ def test_gaus_initializer(self): gaus_num_iter = jnp.sum(sink_out.errors > -1) # check initializer is better - self.assertGreaterEqual(base_num_iter, gaus_num_iter) - - -if __name__ == '__main__': - absltest.main() + assert base_num_iter > gaus_num_iter diff --git a/tests/core/sinkhorn_test.py b/tests/core/sinkhorn_test.py index 912b75268..8f3805e5d 100644 --- a/tests/core/sinkhorn_test.py +++ b/tests/core/sinkhorn_test.py @@ -414,14 +414,11 @@ def test_restart(self, lse_mode: bool): default_a = jnp.ones_like(init_dual_a) default_b = jnp.ones_like(init_dual_b) - np.testing.assert_raises( - AssertionError, 
- lambda: np.testing.assert_allclose(default_a, init_dual_a) - ) - np.testing.assert_raises( - AssertionError, - lambda: np.testing.assert_allclose(default_b, init_dual_b) - ) + with pytest.raises(AssertionError): + np.testing.assert_allclose(default_a, init_dual_a) + + with pytest.raises(AssertionError): + np.testing.assert_allclose(default_b, init_dual_b) out_restarted = sinkhorn.sinkhorn( geom, From f6fdd5cf6f2a6ce050e0157d6b6e6e0c9d0f2006 Mon Sep 17 00:00:00 2001 From: James Thornton Date: Thu, 14 Jul 2022 16:59:05 -0700 Subject: [PATCH 36/46] fix docstring spaces --- ott/core/initializers.py | 43 +++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 139aa8db3..9c7740324 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -35,9 +35,7 @@ def init_dual_b( class DefaultInitializer(SinkhornInitializer): - """Default Initialization of Sinkhorn dual potentials/ primal scalings. - - """ + """Default Initialization of Sinkhorn dual potentials/ primal scalings.""" def init_dual_a( self, ot_problem: linear_problems.LinearProblem, lse_mode: bool @@ -127,13 +125,13 @@ class SortingInitializer(DefaultInitializer): DualSort algorithm from https://arxiv.org/abs/2206.07630, solve non-regularized OT problem via sorting, then compute potential through - iterated minimum on C-transform and use this potentials to initialize + iterated minimum on C-transform and use this potential to initialize regularized potential Args: - vectorized_update: Use vectorized inner loop if true. Defaults to True. - tol: DualSort convergence threshold. Defaults to 1e-2. - max_iter: Max DualSort steps. Defaults to 100. + vectorized_update: Use vectorized inner loop if true. + tolerance: DualSort convergence threshold. + max_iter: Max DualSort steps. """ def __init__( @@ -152,13 +150,14 @@ def __init__( ) def init_sorting_dual( - self, modified_cost: jnp.ndarray, f_potential: jnp.ndarray + self, modified_cost: jnp.ndarray, init_f: jnp.ndarray ) -> jnp.ndarray: """Run DualSort algorithm. Args: modified_cost: cost matrix minus diagonal column-wise. - f_potential: potential f, array of size n. + init_f: potential f, array of size n. This is the starting potential, + which is then updated to make the init potential, so an init of an init. Returns: potential f, array of size n. @@ -177,7 +176,7 @@ def cond_fn(state): it = 0 diff = self.tolerance + 1.0 - state = (f_potential, diff, it) + state = (init_f, diff, it) f_potential, _, it = jax.lax.while_loop( cond_fun=cond_fn, body_fun=body_fn, init_val=state @@ -202,20 +201,24 @@ def init_dual_a( Returns: potential/ scaling f_u, array of size n. """ - cost_matrix = ot_problem.geom.cost_matrix - modified_cost = cost_matrix - jnp.diag(cost_matrix)[None, :] + if ot_problem.geom.is_online: + # raise error/ warning? 
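The guard above exists because DualSort needs the dense cost matrix: for a batched ("online") point cloud, reading geom.cost_matrix would materialize the full n x m array and defeat the purpose of batching, hence the fallback to the default initializer on the next line. For reference, the sweep that DualSort otherwise iterates is a C-transform-style minimum; a toy sketch (a square cost matrix is assumed, which patch 38 below starts enforcing):

import jax.numpy as jnp

C = jnp.array([[0.0, 2.0], [3.0, 0.5]])
modified_cost = C - jnp.diag(C)[None, :]  # subtract the diagonal column-wise
f = jnp.zeros(2)
for _ in range(10):  # fixed sweeps in place of the tolerance-based loop
  # f_i <- min_j (C[i, j] - C[j, j] + f[j])
  f = jnp.min(modified_cost + f[None, :], axis=1)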
+ return super().init_dual_a(ot_problem, lse_mode) + else: + cost_matrix = ot_problem.geom.cost_matrix + modified_cost = cost_matrix - jnp.diag(cost_matrix)[None, :] - n = cost_matrix.shape[0] - f_potential = jnp.zeros(n) if init_f is None else init_f + n = cost_matrix.shape[0] + init_f = jnp.zeros(n) if init_f is None else init_f - f_potential = self.init_sorting_dual(modified_cost, f_potential) - f_potential = f_potential - jnp.mean(f_potential) + f_potential = self.init_sorting_dual(modified_cost, init_f) + f_potential = f_potential - jnp.mean(f_potential) - f_u = f_potential if lse_mode else ot_problem.scaling_from_potential( - f_potential - ) + f_u = f_potential if lse_mode else ot_problem.scaling_from_potential( + f_potential + ) - return f_u + return f_u def _vectorized_update( From c189f18eb66ec5b355aca1200e4afa6c8ae529ec Mon Sep 17 00:00:00 2001 From: James Thornton Date: Fri, 15 Jul 2022 18:34:56 -0700 Subject: [PATCH 37/46] remove spaces and add bibtex --- docs/references.bib | 7 +++++++ ott/core/initializers.py | 11 ++++++----- tests/core/initializers_test.py | 13 ++++++------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/docs/references.bib b/docs/references.bib index edcaff805..2f55db864 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -27,3 +27,10 @@ @InProceedings{scetbon:21 pdf = {http://proceedings.mlr.press/v139/scetbon21a/scetbon21a.pdf}, url = {https://proceedings.mlr.press/v139/scetbon21a.html}, } + +@article{thornton2022rethinking:22, + title={Rethinking Initialization of the Sinkhorn Algorithm}, + author={Thornton, James and Cuturi, Marco}, + journal={arXiv preprint arXiv:2206.07630}, + year={2022} +} diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 9c7740324..d48292909 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Sinkhorn initializers.""" +from abc import ABC, abstractmethod from typing import Optional import jax @@ -21,13 +22,15 @@ from ott.geometry import pointcloud -class SinkhornInitializer: +class SinkhornInitializer(ABC): + @abstractmethod def init_dual_a( self, ot_problem: linear_problems.LinearProblem, lse_mode: bool ) -> jnp.ndarray: """Initialization for Sinkhorn potential/ scaling f_u.""" + @abstractmethod def init_dual_b( self, ot_problem: linear_problems.LinearProblem, lse_mode: bool ) -> jnp.ndarray: @@ -73,7 +76,7 @@ def init_dual_b( class GaussianInitializer(DefaultInitializer): """GaussianInitializer. - From https://arxiv.org/abs/2206.07630. + From :cite:`thornton2022rethinking:22`. Compute Gaussian approximations of each pointcloud, then compute closed from Kantorovich potential betwen Gaussian approximations using Brenier's theorem (adapt convex/ Brenier potential to Kantorovich). Use this Gaussian potential to @@ -123,7 +126,7 @@ def init_dual_a( class SortingInitializer(DefaultInitializer): """Sorting Init class. 
- DualSort algorithm from https://arxiv.org/abs/2206.07630, solve + DualSort algorithm from :cite:`thornton2022rethinking:22`, solve non-regularized OT problem via sorting, then compute potential through iterated minimum on C-transform and use this potential to initialize regularized potential @@ -140,9 +143,7 @@ def __init__( tolerance: float = 1e-2, max_iter: int = 100 ): - super().__init__() - self.tolerance = tolerance self.max_iter = max_iter self.update_fn = lambda f, mod_cost: jax.lax.cond( diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py index 2053dc0a9..cdecfad1e 100644 --- a/tests/core/initializers_test.py +++ b/tests/core/initializers_test.py @@ -101,18 +101,15 @@ def run_sinkhorn_gaus_init(x, y, a=None, b=None, epsilon=0.01): class TestInitializers: - @pytest.fixture(autouse=True) - def initialize(self): - self.rng = jax.random.PRNGKey(42) - @pytest.mark.fast.with_args("vector_min", [False, True]) def test_sorting_init(self, vector_min): """Tests sorting dual initializer.""" n = 500 epsilon = 0.01 + rng = jax.random.PRNGKey(42) - ot_problem = create_sorting_problem(rng=self.rng, n=n, epsilon=epsilon) + ot_problem = create_sorting_problem(rng=rng, n=n, epsilon=epsilon) # run sinkhorn sink_out_base = run_sinkhorn( x=ot_problem.geom.x, @@ -143,8 +140,9 @@ def test_default_initializer(self): m = 200 d = 2 epsilon = 0.01 + rng = jax.random.PRNGKey(42) - ot_problem = create_ot_problem(self.rng, n, m, d) + ot_problem = create_ot_problem(rng, n, m, d) default_potential_a = init_lib.DefaultInitializer().init_dual_a( ot_problem=ot_problem, lse_mode=True @@ -184,8 +182,9 @@ def test_gaus_initializer(self): m = 200 d = 2 epsilon = 0.01 + rng = jax.random.PRNGKey(42) - ot_problem = create_ot_problem(self.rng, n, m, d) + ot_problem = create_ot_problem(rng, n, m, d) # run sinkhorn sink_out = run_sinkhorn( From 3855bb902d92490085da27f9a7755bbf62fbe4a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CJTT94=E2=80=9D?= <“jtthornton1994@gmail.com”> Date: Wed, 17 Aug 2022 14:02:43 +0200 Subject: [PATCH 38/46] add errors for non square cost matrix for sorting, online geoms for initializers, tests --- docs/references.bib | 477 ++++++++++++++++++++++++++++++++ ott/core/initializers.py | 63 ++--- tests/core/initializers_test.py | 78 ++++-- 3 files changed, 565 insertions(+), 53 deletions(-) diff --git a/docs/references.bib b/docs/references.bib index 2f55db864..8ee9bbc05 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -1,3 +1,51 @@ +@InProceedings{vayer:19, + title = {Optimal Transport for structured data with application on graphs}, + author = {Titouan, Vayer and Courty, Nicolas and Tavenard, Romain and Laetitia, Chapel and Flamary, R{\'e}mi}, + booktitle = {Proceedings of the 36th International Conference on Machine Learning}, + pages = {6275--6284}, + year = {2019}, + editor = {Chaudhuri, Kamalika and Salakhutdinov, Ruslan}, + volume = {97}, + series = {Proceedings of Machine Learning Research}, + month = {09--15 Jun}, + publisher = {PMLR}, + pdf = {http://proceedings.mlr.press/v97/titouan19a/titouan19a.pdf}, + url = {https://proceedings.mlr.press/v97/titouan19a.html}, +} + +@InProceedings{peyre:16, + title = {Gromov-Wasserstein Averaging of Kernel and Distance Matrices}, + author = {Peyré, Gabriel and Cuturi, Marco and Solomon, Justin}, + booktitle = {Proceedings of The 33rd International Conference on Machine Learning}, + pages = {2664--2672}, + year = {2016}, + editor = {Balcan, Maria Florina and Weinberger, Kilian Q.}, + volume = {48}, + series = 
{Proceedings of Machine Learning Research}, + address = {New York, New York, USA}, + month = {20--22 Jun}, + publisher = {PMLR}, + pdf = {http://proceedings.mlr.press/v48/peyre16.pdf}, + url = {https://proceedings.mlr.press/v48/peyre16.html}, +} + +@InProceedings{cuturi:14, + title = {Fast Computation of Wasserstein Barycenters}, + author = {Cuturi, Marco and Doucet, Arnaud}, + booktitle = {Proceedings of the 31st International Conference on Machine Learning}, + pages = {685--693}, + year = {2014}, + editor = {Xing, Eric P. and Jebara, Tony}, + volume = {32}, + number = {2}, + series = {Proceedings of Machine Learning Research}, + address = {Bejing, China}, + month = {22--24 Jun}, + publisher = {PMLR}, + pdf = {http://proceedings.mlr.press/v32/cuturi14.pdf}, + url = {https://proceedings.mlr.press/v32/cuturi14.html}, +} + @InProceedings{indyk:19, title = {Sample-Optimal Low-Rank Approximation of Distance Matrices}, author = {Indyk, Pitor and Vakilian, Ali and Wagner, Tal and Woodruff, David P}, @@ -28,6 +76,435 @@ @InProceedings{scetbon:21 url = {https://proceedings.mlr.press/v139/scetbon21a.html}, } +@Article{schiebinger:19, + author = {Schiebinger, Geoffrey and Shu, Jian and Tabaka, Marcin and Cleary, Brian and Subramanian, Vidya + and Solomon, Aryeh and Gould, Joshua and Liu, Siyan and Lin, Stacie and Berube, Peter and Lee, Lia and Chen, Jenny + and Brumbaugh, Justin and Rigollet, Philippe and Hochedlinger, Konrad and Jaenisch, Rudolf + and Regev, Aviv and Lander, Eric S.}, + title = {Optimal-Transport Analysis of Single-Cell Gene Expression Identifies Developmental Trajectories + in Reprogramming}, + journal = {Cell}, + year = {2019}, + month = {Feb}, + day = {07}, + publisher = {Elsevier}, + volume = {176}, + number = {4}, + pages = {928-943.e22}, + issn = {0092-8674}, + doi = {10.1016/j.cell.2019.01.006}, +} + +@Article{memoli:11, + author = "M{\'e}moli, Facundo", + title = "Gromov--Wasserstein Distances and the Metric Approach to Object Matching", + journal = "Foundations of Computational Mathematics", + year = "2011", + month = "Aug", + day = "01", + volume = "11", + number = "4", + pages = "417--487", + issn = "1615-3383", + doi = "10.1007/s10208-011-9093-5", + url = "https://doi.org/10.1007/s10208-011-9093-5" +} + +@InProceedings{scetbon:22, + title = {Linear-Time Gromov {W}asserstein Distances using Low Rank Couplings and Costs}, + author = {Scetbon, Meyer and Peyr{\'e}, Gabriel and Cuturi, Marco}, + booktitle = {Proceedings of the 39th International Conference on Machine Learning}, + pages = {19347--19365}, + year = {2022}, + editor = {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan}, + volume = {162}, + series = {Proceedings of Machine Learning Research}, + month = {17--23 Jul}, + publisher = {PMLR}, + pdf = {https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf}, + url = {https://proceedings.mlr.press/v162/scetbon22b.html}, +} + +@Article{vayer:20, + author = {Vayer, Titouan and Chapel, Laetitia and Flamary, Remi and Tavenard, Romain and Courty, Nicolas}, + title = {Fused Gromov-Wasserstein Distance for Structured Objects}, + journal = {Algorithms}, + volume = {13}, + year = {2020}, + number = {9}, + article-numer = {212}, + url = {https://www.mdpi.com/1999-4893/13/9/212}, + issn = {1999-4893}, + doi = {10.3390/a13090212} +} + +@Article{demetci:20, + author = {Demetci, Pinar and Santorella, Rebecca and Sandstede, Bj{\"o}rn and Noble, William Stafford and Singh, Ritambhara}, + title = {Gromov-Wasserstein 
optimal transport to align single-cell multi-omics data}, + elocation-id = {2020.04.28.066787}, + year = {2020}, + doi = {10.1101/2020.04.28.066787}, + publisher = {Cold Spring Harbor Laboratory}, + URL = {https://www.biorxiv.org/content/early/2020/11/11/2020.04.28.066787}, + eprint = {https://www.biorxiv.org/content/early/2020/11/11/2020.04.28.066787.full.pdf}, + journal = {bioRxiv} +} + +@Article{chen:19, + author = "Chen, Song and Lake, Blue B. and Zhang, Kun", + title = "High-throughput sequencing of the transcriptome and chromatin accessibility in the same cell", + journal = "Nature Biotechnology", + year = "2019", + month = "Dec", + day = "01", + volume = "37", + number = "12", + pages = "1452--1457", + issn = "1546-1696", + doi = "10.1038/s41587-019-0290-0", + url = "https://doi.org/10.1038/s41587-019-0290-0" +} + +@Misc{richter-powell:21, + doi = {10.48550/ARXIV.2111.12187}, + url = {https://arxiv.org/abs/2111.12187}, + author = {Richter-Powell, Jack and Lorraine, Jonathan and Amos, Brandon}, + keywords = {Machine Learning (cs.LG), Machine Learning (stat.ML), FOS: Computer and information sciences, FOS: Computer and information sciences}, + title = {Input Convex Gradient Networks}, + publisher = {arXiv}, + year = {2021}, + copyright = {arXiv.org perpetual, non-exclusive license} +} + +@Misc{bunne:22, + doi = {10.48550/ARXIV.2206.14262}, + url = {https://arxiv.org/abs/2206.14262}, + author = {Bunne, Charlotte and Krause, Andreas and Cuturi, Marco}, + keywords = {Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences}, + title = {Supervised Training of Conditional Monge Maps}, + publisher = {arXiv}, + year = {2022}, + copyright = {Creative Commons Attribution Non Commercial Share Alike 4.0 International} +} + +@Article{gelbrich:90, + author = {Gelbrich, Matthias}, + title = {On a Formula for the L2 Wasserstein Metric between Measures on Euclidean and Hilbert Spaces}, + journal = {Mathematische Nachrichten}, + volume = {147}, + number = {1}, + pages = {185-203}, + doi = {https://doi.org/10.1002/mana.19901470121}, + url = {https://onlinelibrary.wiley.com/doi/abs/10.1002/mana.19901470121}, + eprint = {https://onlinelibrary.wiley.com/doi/pdf/10.1002/mana.19901470121}, + year = {1990} +} + +@InProceedings{amos:17, + title = {Input Convex Neural Networks}, + author = {Brandon Amos and Lei Xu and J. 
Zico Kolter}, + booktitle = {Proceedings of the 34th International Conference on Machine Learning}, + pages = {146--155}, + year = {2017}, + editor = {Precup, Doina and Teh, Yee Whye}, + volume = {70}, + series = {Proceedings of Machine Learning Research}, + month = {06--11 Aug}, + publisher = {PMLR}, + pdf = {http://proceedings.mlr.press/v70/amos17b/amos17b.pdf}, + url = {https://proceedings.mlr.press/v70/amos17b.html}, +} + +@InProceedings{makkuva:20, + title = {Optimal transport mapping via input convex neural networks}, + author = {Makkuva, Ashok and Taghvaei, Amirhossein and Oh, Sewoong and Lee, Jason}, + booktitle = {Proceedings of the 37th International Conference on Machine Learning}, + pages = {6672--6681}, + year = {2020}, + editor = {III, Hal Daumé and Singh, Aarti}, + volume = {119}, + series = {Proceedings of Machine Learning Research}, + month = {13--18 Jul}, + publisher = {PMLR}, + pdf = {http://proceedings.mlr.press/v119/makkuva20a/makkuva20a.pdf}, + url = {https://proceedings.mlr.press/v119/makkuva20a.html}, +} + +@InProceedings{cuturi:19, + author = {Cuturi, Marco and Teboul, Olivier and Vert, Jean-Philippe}, + booktitle = {Advances in Neural Information Processing Systems}, + editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett}, + pages = {}, + publisher = {Curran Associates, Inc.}, + title = {Differentiable Ranking and Sorting using Optimal Transport}, + url = {https://proceedings.neurips.cc/paper/2019/file/d8c24ca8f23c562a5600876ca2a550ce-Paper.pdf}, + volume = {32}, + year = {2019} +} + +@InProceedings{cuturi:20a, + title = {Supervised Quantile Normalization for Low Rank Matrix Factorization}, + author = {Cuturi, Marco and Teboul, Olivier and Niles-Weed, Jonathan and Vert, Jean-Philippe}, + booktitle = {Proceedings of the 37th International Conference on Machine Learning}, + pages = {2269--2279}, + year = {2020}, + editor = {III, Hal Daumé and Singh, Aarti}, + volume = {119}, + series = {Proceedings of Machine Learning Research}, + month = {13--18 Jul}, + publisher = {PMLR}, + pdf = {http://proceedings.mlr.press/v119/cuturi20a/cuturi20a.pdf}, + url = {https://proceedings.mlr.press/v119/cuturi20a.html}, +} + +@InProceedings{gramfort:15, + title = {Fast optimal transport averaging of neuroimaging data}, + author = {Gramfort, Alexandre and Peyr{\'e}, Gabriel and Cuturi, Marco}, + booktitle = {International Conference on Information Processing in Medical Imaging}, + pages = {261--272}, + year = {2015}, + organization = {Springer} +} + +@Article{benamou:15, + author = {Benamou, Jean-David and Carlier, Guillaume and Cuturi, Marco and Nenna, Luca and Peyr\'{e}, Gabriel}, + title = {Iterative Bregman Projections for Regularized Transportation Problems}, + journal = {SIAM Journal on Scientific Computing}, + volume = {37}, + number = {2}, + pages = {A1111-A1138}, + year = {2015}, + doi = {10.1137/141000439}, + URL = {https://doi.org/10.1137/141000439}, + eprint = {https://doi.org/10.1137/141000439} +} + +@InProceedings{cuturi:13, + author = {Cuturi, Marco}, + booktitle = {Advances in Neural Information Processing Systems}, + editor = {C.J. Burges and L. Bottou and M. Welling and Z. Ghahramani and K.Q. 
Weinberger}, + pages = {}, + publisher = {Curran Associates, Inc.}, + title = {Sinkhorn Distances: Lightspeed Computation of Optimal Transport}, + url = {https://proceedings.neurips.cc/paper/2013/file/af21d0c97db2e27e13572cbf59eb343d-Paper.pdf}, + volume = {26}, + year = {2013} +} + +@Article{peyre:19, + author = {Gabriel Peyré and Marco Cuturi}, + url = {http://dx.doi.org/10.1561/2200000073}, + year = {2019}, + volume = {11}, + journal = {Foundations and Trends® in Machine Learning}, + title = {Computational Optimal Transport: With Applications to Data Science}, + doi = {10.1561/2200000073}, + issn = {1935-8237}, + number = {5-6}, + pages = {355-607}, +} + +@Article{solomon:15, + author = {Solomon, Justin and de Goes, Fernando and Peyr\'{e}, Gabriel and Cuturi, Marco and Butscher, Adrian and + Nguyen, Andy and Du, Tao and Guibas, Leonidas}, + title = {Convolutional Wasserstein Distances: Efficient Optimal Transportation on Geometric Domains}, + year = {2015}, + issue_date = {August 2015}, + publisher = {Association for Computing Machinery}, + address = {New York, NY, USA}, + volume = {34}, + number = {4}, + issn = {0730-0301}, + url = {https://doi.org/10.1145/2766963}, + doi = {10.1145/2766963}, + journal = {ACM Trans. Graph.}, + month = {jul}, + articleno = {66}, + numpages = {11}, + keywords = {entropy, wasserstein distances, optimal transportation, displacement interpolation} +} + +@InProceedings{genevay:18, + title = {Learning Generative Models with Sinkhorn Divergences}, + author = {Genevay, Aude and Peyre, Gabriel and Cuturi, Marco}, + booktitle = {Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics}, + pages = {1608--1617}, + year = {2018}, + editor = {Storkey, Amos and Perez-Cruz, Fernando}, + volume = {84}, + series = {Proceedings of Machine Learning Research}, + month = {09--11 Apr}, + publisher = {PMLR}, + pdf = {http://proceedings.mlr.press/v84/genevay18a/genevay18a.pdf}, + url = {https://proceedings.mlr.press/v84/genevay18a.html}, +} + +@Misc{sejourne:19, + doi = {10.48550/ARXIV.1910.12958}, + url = {https://arxiv.org/abs/1910.12958}, + author = {Séjourné, Thibault and Feydy, Jean and Vialard, François-Xavier and Trouvé, Alain and Peyré, Gabriel}, + keywords = {Optimization and Control (math.OC), Machine Learning (cs.LG), Machine Learning (stat.ML), + FOS: Mathematics, FOS: Mathematics, FOS: Computer and information sciences, FOS: Computer and information sciences}, + title = {Sinkhorn Divergences for Unbalanced Optimal Transport}, + publisher = {arXiv}, + year = {2019}, + copyright = {arXiv.org perpetual, non-exclusive license} +} + +@Article{janati:20, + title = {Entropic optimal transport between unbalanced Gaussian measures has a closed form}, + author = {Janati, Hicham and Muzellec, Boris and Peyr{\'e}, Gabriel and Cuturi, Marco}, + journal = {Advances in neural information processing systems}, + volume = {33}, + pages = {10468--10479}, + year = {2020} +} + +@Article{chen:19a, + author = {Chen, Yongxin and Georgiou, Tryphon T. 
and Tannenbaum, Allen}, + journal = {IEEE Access}, + title = {Optimal Transport for Gaussian Mixture Models}, + year = {2019}, + volume = {7}, + number = {}, + pages = {6269-6278}, + doi = {10.1109/ACCESS.2018.2889838} +} + +@Article{delon:20, + author = {Delon, Julie and Desolneux, Agn\`{e}s}, + title = {A Wasserstein-Type Distance in the Space of Gaussian Mixture Models}, + journal = {SIAM Journal on Imaging Sciences}, + volume = {13}, + number = {2}, + pages = {936-970}, + year = {2020}, + doi = {10.1137/19M1301047}, + URL = {https://doi.org/10.1137/19M1301047}, + eprint = {https://doi.org/10.1137/19M1301047}, +} + +@InProceedings{janati:20a, + title = {Debiased {S}inkhorn barycenters}, + author = {Janati, Hicham and Cuturi, Marco and Gramfort, Alexandre}, + booktitle = {Proceedings of the 37th International Conference on Machine Learning}, + pages = {4692--4701}, + year = {2020}, + editor = {III, Hal Daumé and Singh, Aarti}, + volume = {119}, + series = {Proceedings of Machine Learning Research}, + month = {13--18 Jul}, + publisher = {PMLR}, + pdf = {http://proceedings.mlr.press/v119/janati20a/janati20a.pdf}, + url = {https://proceedings.mlr.press/v119/janati20a.html}, +} + +@Article{schmitz:18, + author = {Schmitz, Morgan A. and Heitz, Matthieu and Bonneel, Nicolas and Ngol\`{e}, Fred and Coeurjolly, David and + Cuturi, Marco and Peyr\'{e}, Gabriel and Starck, Jean-Luc}, + title = {Wasserstein Dictionary Learning: Optimal Transport-Based Unsupervised Nonlinear Dictionary Learning}, + journal = {SIAM Journal on Imaging Sciences}, + volume = {11}, + number = {1}, + pages = {643-678}, + year = {2018}, + doi = {10.1137/17M1140431}, + URL = {https://doi.org/10.1137/17M1140431}, + eprint = {https://doi.org/10.1137/17M1140431}, +} + +@Article{alvarez-esteban:16, + title = {A fixed-point approach to barycenters in Wasserstein space}, + journal = {Journal of Mathematical Analysis and Applications}, + volume = {441}, + number = {2}, + pages = {744-762}, + year = {2016}, + issn = {0022-247X}, + doi = {https://doi.org/10.1016/j.jmaa.2016.04.045}, + url = {https://www.sciencedirect.com/science/article/pii/S0022247X16300907}, + author = {Pedro C. Álvarez-Esteban and E. {del Barrio} and J.A. Cuesta-Albertos and C. Matrán}, + keywords = {Mass transportation problem, -Wasserstein distance, Wasserstein barycenter, Fréchet mean, + Fixed-point iteration, Location-scatter families}, +} + +@Article{lehmann:21, + author = "Lehmann, Tobias and von Renesse, Max-K. and Sambale, Alexander and Uschmajew, Andr{\'e}", + title = "A note on overrelaxation in the Sinkhorn algorithm", + journal = "Optimization Letters", + year = "2021", + month = "Dec", + day = "14", + issn = "1862-4480", + doi = "10.1007/s11590-021-01830-0", + url = "https://doi.org/10.1007/s11590-021-01830-0" +} + +@inproceedings{sejourne:21, + author = {Sejourne, Thibault and Vialard, Francois-Xavier and Peyr\'{e}, Gabriel}, + booktitle = {Advances in Neural Information Processing Systems}, + editor = {M. Ranzato and A. Beygelzimer and Y. Dauphin and P.S. Liang and J. 
Wortman Vaughan}, + pages = {8766--8779}, + publisher = {Curran Associates, Inc.}, + title = {The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation}, + url = {https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf}, + volume = {34}, + year = {2021} +} + +@inproceedings{chizat:20, + author = {Chizat, L\'{e}na\"{\i}c and Roussillon, Pierre and L\'{e}ger, Flavien and Vialard, + Fran\c{c}ois-Xavier and Peyr\'{e}, Gabriel}, + booktitle = {Advances in Neural Information Processing Systems}, + editor = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin}, + pages = {2257--2269}, + publisher = {Curran Associates, Inc.}, + title = {Faster Wasserstein Distance Estimation with the Sinkhorn Divergence}, + url = {https://proceedings.neurips.cc/paper/2020/file/17f98ddf040204eda0af36a108cbdea4-Paper.pdf}, + volume = {33}, + year = {2020} +} + +@Article{higham:1997, + author = "Higham, Nicholas J.", + title = "Stable iterations for the matrix square root", + journal = "Numerical Algorithms", + year = "1997", + month = "Sep", + day = "01", + volume = "15", + number = "2", + pages = "227--242", + issn = "1572-9265", + doi = "10.1023/A:1019150005407", + url = "https://doi.org/10.1023/A:1019150005407" +} + +@Article{lloyd:82, + author={Lloyd, S.}, + journal={IEEE Transactions on Information Theory}, + title={Least squares quantization in PCM}, + year={1982}, + volume={28}, + number={2}, + pages={129-137}, + doi={10.1109/TIT.1982.1056489} +} + +@inproceedings{arthur:07, + author = {Arthur, David and Vassilvitskii, Sergei}, + title = {K-Means++: The Advantages of Careful Seeding}, + year = {2007}, + isbn = {9780898716245}, + publisher = {Society for Industrial and Applied Mathematics}, + address = {USA}, + booktitle = {Proceedings of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms}, + pages = {1027–1035}, + numpages = {9}, + location = {New Orleans, Louisiana}, + series = {SODA '07} +} + @article{thornton2022rethinking:22, title={Rethinking Initialization of the Sinkhorn Algorithm}, author={Thornton, James and Cuturi, Marco}, diff --git a/ott/core/initializers.py b/ott/core/initializers.py index d48292909..178cd6df0 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -104,23 +104,22 @@ def init_dual_a( # import Gaussian here due to circular imports from ott.tools.gaussian_mixture import gaussian - if not isinstance(ot_problem.geom, pointcloud.PointCloud): - # warning that init not applied - return super().init_dual_a(ot_problem, lse_mode) - else: - - x, y = ot_problem.geom.x, ot_problem.geom.y - a, b = ot_problem.a, ot_problem.b - - gaussian_a = gaussian.Gaussian.from_samples(x, weights=a) - gaussian_b = gaussian.Gaussian.from_samples(y, weights=b) - # Brenier potential for cost ||x-y||^2/2, multiply by two for ||x-y||^2 - f_potential = 2 * gaussian_a.f_potential(dest=gaussian_b, points=x) - f_potential = f_potential - jnp.mean(f_potential) - f_u = f_potential if lse_mode else ot_problem.scaling_from_potential( - f_potential - ) - return f_u + assert isinstance( + ot_problem.geom, pointcloud.PointCloud + ), "Gaussian initializer valid only for PointCloud geom" + + x, y = ot_problem.geom.x, ot_problem.geom.y + a, b = ot_problem.a, ot_problem.b + + gaussian_a = gaussian.Gaussian.from_samples(x, weights=a) + gaussian_b = gaussian.Gaussian.from_samples(y, weights=b) + # Brenier potential for cost ||x-y||^2/2, multiply by two for ||x-y||^2 + f_potential = 2 * gaussian_a.f_potential(dest=gaussian_b, 
points=x) + f_potential = f_potential - jnp.mean(f_potential) + f_u = f_potential if lse_mode else ot_problem.scaling_from_potential( + f_potential + ) + return f_u class SortingInitializer(DefaultInitializer): @@ -202,24 +201,26 @@ def init_dual_a( Returns: potential/ scaling f_u, array of size n. """ - if ot_problem.geom.is_online: - # raise error/ warning? - return super().init_dual_a(ot_problem, lse_mode) - else: - cost_matrix = ot_problem.geom.cost_matrix - modified_cost = cost_matrix - jnp.diag(cost_matrix)[None, :] + assert not ot_problem.geom.is_online, "Sorting initializer does not work for online geom" + # check for sorted x, y requires pointcloud and could slow initializer + cost_matrix = ot_problem.geom.cost_matrix - n = cost_matrix.shape[0] - init_f = jnp.zeros(n) if init_f is None else init_f + assert cost_matrix.shape[0] == cost_matrix.shape[ + 1], "Requires square cost matrix" - f_potential = self.init_sorting_dual(modified_cost, init_f) - f_potential = f_potential - jnp.mean(f_potential) + modified_cost = cost_matrix - jnp.diag(cost_matrix)[None, :] - f_u = f_potential if lse_mode else ot_problem.scaling_from_potential( - f_potential - ) + n = cost_matrix.shape[0] + init_f = jnp.zeros(n) if init_f is None else init_f + + f_potential = self.init_sorting_dual(modified_cost, init_f) + f_potential = f_potential - jnp.mean(f_potential) + + f_u = f_potential if lse_mode else ot_problem.scaling_from_potential( + f_potential + ) - return f_u + return f_u def _vectorized_update( diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py index cdecfad1e..5084d1919 100644 --- a/tests/core/initializers_test.py +++ b/tests/core/initializers_test.py @@ -24,7 +24,7 @@ from ott.geometry import geometry, pointcloud -def create_sorting_problem(rng, n, epsilon=0.01): +def create_sorting_problem(rng, n, epsilon=0.01, online=False): # definte ot problem x_init = jnp.array([-1., 0, .22]) y_init = jnp.array([0., 0, 1.1]) @@ -41,15 +41,19 @@ def create_sorting_problem(rng, n, epsilon=0.01): a = np.ones(n) / n b = np.ones(m) / m + batch_size = 3 if online else None geom = pointcloud.PointCloud( - x.reshape(-1, 1), y.reshape(-1, 1), epsilon=epsilon + x.reshape(-1, 1), + y.reshape(-1, 1), + epsilon=epsilon, + batch_size=batch_size ) ot_problem = linear_problems.LinearProblem(geom=geom, a=a, b=b) return ot_problem -def create_ot_problem(rng, n, m, d, epsilon=0.01): +def create_ot_problem(rng, n, m, d, epsilon=0.01, online=False): # definte ot problem x_rng, y_rng = jax.random.split(rng) @@ -63,8 +67,10 @@ def create_ot_problem(rng, n, m, d, epsilon=0.01): b = np.ones(m) / m x_jnp, y_jnp = jnp.array(x), jnp.array(y) - - geom = pointcloud.PointCloud(x_jnp, y_jnp, epsilon=epsilon) + batch_size = 3 if online else None + geom = pointcloud.PointCloud( + x_jnp, y_jnp, epsilon=epsilon, batch_size=batch_size + ) ot_problem = linear_problems.LinearProblem(geom=geom, a=a, b=b) return ot_problem @@ -101,15 +107,16 @@ def run_sinkhorn_gaus_init(x, y, a=None, b=None, epsilon=0.01): class TestInitializers: - @pytest.mark.fast.with_args("vector_min", [False, True]) + @pytest.mark.fast.with_args("vector_min", [True, False]) def test_sorting_init(self, vector_min): """Tests sorting dual initializer.""" - n = 500 epsilon = 0.01 rng = jax.random.PRNGKey(42) - ot_problem = create_sorting_problem(rng=rng, n=n, epsilon=epsilon) + ot_problem = create_sorting_problem( + rng=rng, n=n, epsilon=epsilon, online=False + ) # run sinkhorn sink_out_base = run_sinkhorn( x=ot_problem.geom.x, @@ -133,6 +140,32 @@ 
def test_sorting_init(self, vector_min): # check initializer is better or equal assert base_num_iter > sort_num_iter + @pytest.mark.fast + def test_sorting_init_online(self): + n = 500 + epsilon = 0.01 + rng = jax.random.PRNGKey(42) + + ot_problem = create_sorting_problem( + rng=rng, n=n, epsilon=epsilon, online=True + ) + sort_init = init_lib.SortingInitializer(vectorized_update=True) + with pytest.raises(AssertionError): + sort_init.init_dual_a(ot_problem=ot_problem, lse_mode=True) + + @pytest.mark.fast + def test_sorting_init_square_cost(self): + n = 100 + m = 150 + d = 1 + epsilon = 0.01 + rng = jax.random.PRNGKey(42) + + ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) + sort_init = init_lib.SortingInitializer(vectorized_update=True) + with pytest.raises(AssertionError): + sort_init.init_dual_a(ot_problem=ot_problem, lse_mode=True) + @pytest.mark.fast def test_default_initializer(self): """Tests default initializer""" @@ -142,7 +175,7 @@ def test_default_initializer(self): epsilon = 0.01 rng = jax.random.PRNGKey(42) - ot_problem = create_ot_problem(rng, n, m, d) + ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) default_potential_a = init_lib.DefaultInitializer().init_dual_a( ot_problem=ot_problem, lse_mode=True @@ -155,8 +188,16 @@ def test_default_initializer(self): np.testing.assert_array_equal(jnp.zeros(n), default_potential_a) np.testing.assert_array_equal(jnp.zeros(m), default_potential_b) - # check gausian init returns 0 for non point cloud geometry - # init initializer + @pytest.mark.fast + def test_gaus_pointcloud_geom(self): + n = 200 + m = 200 + d = 2 + epsilon = 0.01 + rng = jax.random.PRNGKey(42) + + ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) + gaus_init = init_lib.GaussianInitializer() new_geom = geometry.Geometry( cost_matrix=ot_problem.geom.cost_matrix, epsilon=epsilon @@ -164,17 +205,11 @@ def test_default_initializer(self): ot_problem = linear_problems.LinearProblem( geom=new_geom, a=ot_problem.a, b=ot_problem.b ) - init_potential_a = gaus_init.init_dual_a( - ot_problem=ot_problem, lse_mode=True - ) - init_potential_b = gaus_init.init_dual_b( - ot_problem=ot_problem, lse_mode=True - ) - np.testing.assert_array_equal(jnp.zeros(n), init_potential_a) - np.testing.assert_array_equal(jnp.zeros(m), init_potential_b) + with pytest.raises(AssertionError): + gaus_init.init_dual_a(ot_problem=ot_problem, lse_mode=True) - @pytest.mark.fast + @pytest.mark.fast.with_args() def test_gaus_initializer(self): """Tests Gaussian initializer""" # definte ot problem @@ -184,7 +219,7 @@ def test_gaus_initializer(self): epsilon = 0.01 rng = jax.random.PRNGKey(42) - ot_problem = create_ot_problem(rng, n, m, d) + ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) # run sinkhorn sink_out = run_sinkhorn( @@ -195,7 +230,6 @@ def test_gaus_initializer(self): epsilon=epsilon ) base_num_iter = jnp.sum(sink_out.errors > -1) - sink_out = run_sinkhorn_gaus_init( x=ot_problem.geom.x, y=ot_problem.geom.y, From 49e5b4f73d78d0cfd69eeb9de1083d2e4fb96003 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CJTT94=E2=80=9D?= <“jtthornton1994@gmail.com”> Date: Wed, 17 Aug 2022 14:11:58 +0200 Subject: [PATCH 39/46] merge fix lint --- docs/references.bib | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/references.bib b/docs/references.bib index e2bdedee6..8ee9bbc05 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -511,4 +511,3 @@ @article{thornton2022rethinking:22 journal={arXiv preprint 
arXiv:2206.07630}, year={2022} } - From ce4d14c3c5f451c9f448ee53c8549790931f4a73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CJTT94=E2=80=9D?= <“jtthornton1994@gmail.com”> Date: Wed, 17 Aug 2022 17:03:40 +0200 Subject: [PATCH 40/46] merge fix lint --- ott/core/initializers.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 178cd6df0..50641ed08 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -28,29 +28,29 @@ class SinkhornInitializer(ABC): def init_dual_a( self, ot_problem: linear_problems.LinearProblem, lse_mode: bool ) -> jnp.ndarray: - """Initialization for Sinkhorn potential/ scaling f_u.""" + """Initialization for Sinkhorn potential/scaling f_u.""" @abstractmethod def init_dual_b( self, ot_problem: linear_problems.LinearProblem, lse_mode: bool ) -> jnp.ndarray: - """Initialization for Sinkhorn potential/ scaling g_v.""" + """Initialization for Sinkhorn potential/scaling g_v.""" class DefaultInitializer(SinkhornInitializer): - """Default Initialization of Sinkhorn dual potentials/ primal scalings.""" + """Default Initialization of Sinkhorn dual potentials/primal scalings.""" def init_dual_a( self, ot_problem: linear_problems.LinearProblem, lse_mode: bool ) -> jnp.ndarray: - """Initialization for Sinkhorn potential/ scaling f_u. + """Initialization for Sinkhorn potential/scaling f_u. Args: ot_problem: OT problem between discrete distributions of size n and m. lse_mode: Return potential if true, scaling if false. Returns: - potential/ scaling, array of size n + potential/scaling, array of size n """ a = ot_problem.a init_dual_a = jnp.zeros_like(a) if lse_mode else jnp.ones_like(a) @@ -59,14 +59,14 @@ def init_dual_a( def init_dual_b( self, ot_problem: linear_problems.LinearProblem, lse_mode: bool ) -> jnp.ndarray: - """Initialization for Sinkhorn potential/ scaling g_v. + """Initialization for Sinkhorn potential/scaling g_v. Args: ot_problem: OT problem between discrete distributions of size n and m. lse_mode: Return potential if true, scaling if false. Returns: - potential/ scaling, array of size m + potential/scaling, array of size m """ b = ot_problem.b init_dual_b = jnp.zeros_like(b) if lse_mode else jnp.ones_like(b) @@ -79,14 +79,11 @@ class GaussianInitializer(DefaultInitializer): From :cite:`thornton2022rethinking:22`. Compute Gaussian approximations of each pointcloud, then compute closed from Kantorovich potential betwen Gaussian approximations using Brenier's theorem - (adapt convex/ Brenier potential to Kantorovich). Use this Gaussian potential to - initialize Sinkhorn potentials/ scalings. + (adapt convex/Brenier potential to Kantorovich). Use this Gaussian potential to + initialize Sinkhorn potentials/scalings. """ - def __init__(self): - super().__init__() - def init_dual_a( self, ot_problem: linear_problems.LinearProblem, @@ -99,7 +96,7 @@ def init_dual_a( lse_mode: Return potential if true, scaling if false. Returns: - potential/ scaling f_u, array of size n. + potential/scaling f_u, array of size n. """ # import Gaussian here due to circular imports from ott.tools.gaussian_mixture import gaussian @@ -199,7 +196,7 @@ def init_dual_a( which is then updated to make the init potential, so an init of an init. Returns: - potential/ scaling f_u, array of size n. + potential/scaling f_u, array of size n. 
""" assert not ot_problem.geom.is_online, "Sorting initializer does not work for online geom" # check for sorted x, y requires pointcloud and could slow initializer From 1f93053118c9f96dfedb9429c34a2aa333294cf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CJTT94=E2=80=9D?= <“jtthornton1994@gmail.com”> Date: Wed, 17 Aug 2022 17:30:14 +0200 Subject: [PATCH 41/46] add initializers as pytees --- ott/core/initializers.py | 27 ++++++++++++++++++++++++--- tests/core/initializers_test.py | 16 ++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 50641ed08..9aa0a3cd0 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -13,7 +13,7 @@ # limitations under the License. """Sinkhorn initializers.""" from abc import ABC, abstractmethod -from typing import Optional +from typing import Any, Dict, Optional, Sequence, Tuple import jax import jax.numpy as jnp @@ -22,8 +22,18 @@ from ott.geometry import pointcloud +@jax.tree_util.register_pytree_node_class class SinkhornInitializer(ABC): + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + return ([], {}) + + @classmethod + def tree_unflatten( + cls, aux_data: Dict[str, Any], children: Sequence[Any] + ) -> "SinkhornInitializer": + return cls(*children, **aux_data) + @abstractmethod def init_dual_a( self, ot_problem: linear_problems.LinearProblem, lse_mode: bool @@ -37,6 +47,7 @@ def init_dual_b( """Initialization for Sinkhorn potential/scaling g_v.""" +@jax.tree_util.register_pytree_node_class class DefaultInitializer(SinkhornInitializer): """Default Initialization of Sinkhorn dual potentials/primal scalings.""" @@ -73,6 +84,7 @@ def init_dual_b( return init_dual_b +@jax.tree_util.register_pytree_node_class class GaussianInitializer(DefaultInitializer): """GaussianInitializer. @@ -113,12 +125,13 @@ def init_dual_a( # Brenier potential for cost ||x-y||^2/2, multiply by two for ||x-y||^2 f_potential = 2 * gaussian_a.f_potential(dest=gaussian_b, points=x) f_potential = f_potential - jnp.mean(f_potential) - f_u = f_potential if lse_mode else ot_problem.scaling_from_potential( + f_u = f_potential if lse_mode else ot_problem.geom.scaling_from_potential( f_potential ) return f_u +@jax.tree_util.register_pytree_node_class class SortingInitializer(DefaultInitializer): """Sorting Init class. 
@@ -142,10 +155,18 @@ def __init__( super().__init__() self.tolerance = tolerance self.max_iter = max_iter + self.vectorized_update = vectorized_update self.update_fn = lambda f, mod_cost: jax.lax.cond( vectorized_update, _vectorized_update, _coordinate_update, f, mod_cost ) + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + return ([], { + 'tolerance': self.tolerance, + 'max_iter': self.max_iter, + 'vectorized_update': self.vectorized_update + }) + def init_sorting_dual( self, modified_cost: jnp.ndarray, init_f: jnp.ndarray ) -> jnp.ndarray: @@ -213,7 +234,7 @@ def init_dual_a( f_potential = self.init_sorting_dual(modified_cost, init_f) f_potential = f_potential - jnp.mean(f_potential) - f_u = f_potential if lse_mode else ot_problem.scaling_from_potential( + f_u = f_potential if lse_mode else ot_problem.geom.scaling_from_potential( f_potential ) diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py index 5084d1919..d6788c814 100644 --- a/tests/core/initializers_test.py +++ b/tests/core/initializers_test.py @@ -107,6 +107,22 @@ def run_sinkhorn_gaus_init(x, y, a=None, b=None, epsilon=0.01): class TestInitializers: + @pytest.mark.fast + def test_init_pytree(self): + + @jax.jit + def init_sort(): + init = init_lib.SortingInitializer() + return init + + @jax.jit + def init_gaus(): + init = init_lib.GaussianInitializer() + return init + + init_gaus() + init_sort() + @pytest.mark.fast.with_args("vector_min", [True, False]) def test_sorting_init(self, vector_min): """Tests sorting dual initializer.""" From f6659114ba50671c7b5a93256b79a2b8983783df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CJTT94=E2=80=9D?= <“jtthornton1994@gmail.com”> Date: Wed, 17 Aug 2022 18:16:27 +0200 Subject: [PATCH 42/46] add init scaling tests --- tests/core/initializers_test.py | 67 +++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 24 deletions(-) diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py index d6788c814..5f1f7f042 100644 --- a/tests/core/initializers_test.py +++ b/tests/core/initializers_test.py @@ -13,6 +13,8 @@ # Lint as: python3 """Tests for the Gromov Wasserstein.""" +from functools import partial + import jax import jax.numpy as jnp import numpy as np @@ -77,30 +79,40 @@ def create_ot_problem(rng, n, m, d, epsilon=0.01, online=False): # define sinkhorn functions -@jax.jit -def run_sinkhorn_sort_init(x, y, a=None, b=None, epsilon=0.01, vector_min=True): +@partial(jax.jit, static_argnames=['lse_mode']) +def run_sinkhorn_sort_init( + x, y, a=None, b=None, epsilon=0.01, vector_min=True, lse_mode=True +): geom = pointcloud.PointCloud(x, y, epsilon=epsilon) sort_init = init_lib.SortingInitializer(vectorized_update=vector_min) - out = sinkhorn(geom, a=a, b=b, jit=True, potential_initializer=sort_init) + out = sinkhorn( + geom, + a=a, + b=b, + jit=True, + potential_initializer=sort_init, + lse_mode=lse_mode + ) return out -@jax.jit -def run_sinkhorn(x, y, a=None, b=None, epsilon=0.01): +@partial(jax.jit, static_argnames=['lse_mode']) +def run_sinkhorn(x, y, a=None, b=None, epsilon=0.01, lse_mode=True): geom = pointcloud.PointCloud(x, y, epsilon=epsilon) - out = sinkhorn(geom, a=a, b=b, jit=True) + out = sinkhorn(geom, a=a, b=b, jit=True, lse_mode=lse_mode) return out -@jax.jit -def run_sinkhorn_gaus_init(x, y, a=None, b=None, epsilon=0.01): +@partial(jax.jit, static_argnames=['lse_mode']) +def run_sinkhorn_gaus_init(x, y, a=None, b=None, epsilon=0.01, lse_mode=True): geom = pointcloud.PointCloud(x, y, epsilon=epsilon) out 
= sinkhorn( geom, a=a, b=b, jit=True, - potential_initializer=init_lib.GaussianInitializer() + potential_initializer=init_lib.GaussianInitializer(), + lse_mode=lse_mode ) return out @@ -123,8 +135,10 @@ def init_gaus(): init_gaus() init_sort() - @pytest.mark.fast.with_args("vector_min", [True, False]) - def test_sorting_init(self, vector_min): + @pytest.mark.fast.with_args( + "vector_min, lse_mode", [(True, True), (True, False), (False, True)] + ) + def test_sorting_init(self, vector_min, lse_mode): """Tests sorting dual initializer.""" n = 500 epsilon = 0.01 @@ -149,18 +163,20 @@ def test_sorting_init(self, vector_min): a=ot_problem.a, b=ot_problem.b, epsilon=epsilon, - vector_min=vector_min + vector_min=vector_min, + lse_mode=lse_mode ) sort_num_iter = jnp.sum(sink_out_init.errors > -1) # check initializer is better or equal - assert base_num_iter > sort_num_iter + if lse_mode: + assert base_num_iter >= sort_num_iter @pytest.mark.fast def test_sorting_init_online(self): - n = 500 + n = 100 epsilon = 0.01 - rng = jax.random.PRNGKey(42) + rng = jax.random.PRNGKey(0) ot_problem = create_sorting_problem( rng=rng, n=n, epsilon=epsilon, online=True @@ -175,7 +191,7 @@ def test_sorting_init_square_cost(self): m = 150 d = 1 epsilon = 0.01 - rng = jax.random.PRNGKey(42) + rng = jax.random.PRNGKey(0) ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) sort_init = init_lib.SortingInitializer(vectorized_update=True) @@ -189,7 +205,7 @@ def test_default_initializer(self): m = 200 d = 2 epsilon = 0.01 - rng = jax.random.PRNGKey(42) + rng = jax.random.PRNGKey(0) ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) @@ -210,7 +226,7 @@ def test_gaus_pointcloud_geom(self): m = 200 d = 2 epsilon = 0.01 - rng = jax.random.PRNGKey(42) + rng = jax.random.PRNGKey(0) ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) @@ -225,15 +241,15 @@ def test_gaus_pointcloud_geom(self): with pytest.raises(AssertionError): gaus_init.init_dual_a(ot_problem=ot_problem, lse_mode=True) - @pytest.mark.fast.with_args() - def test_gaus_initializer(self): + @pytest.mark.fast.with_args('lse_mode', [True, False]) + def test_gaus_initializer(self, lse_mode): """Tests Gaussian initializer""" # definte ot problem n = 200 m = 200 d = 2 epsilon = 0.01 - rng = jax.random.PRNGKey(42) + rng = jax.random.PRNGKey(0) ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) @@ -243,7 +259,8 @@ def test_gaus_initializer(self): y=ot_problem.geom.y, a=ot_problem.a, b=ot_problem.b, - epsilon=epsilon + epsilon=epsilon, + lse_mode=lse_mode ) base_num_iter = jnp.sum(sink_out.errors > -1) sink_out = run_sinkhorn_gaus_init( @@ -251,9 +268,11 @@ def test_gaus_initializer(self): y=ot_problem.geom.y, a=ot_problem.a, b=ot_problem.b, - epsilon=epsilon + epsilon=epsilon, + lse_mode=lse_mode ) gaus_num_iter = jnp.sum(sink_out.errors > -1) # check initializer is better - assert base_num_iter > gaus_num_iter + if lse_mode: + assert base_num_iter >= gaus_num_iter From f4d4c1ef63087afadf272cfa170823e44ef72d59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CJTT94=E2=80=9D?= <“jtthornton1994@gmail.com”> Date: Wed, 17 Aug 2022 19:21:41 +0200 Subject: [PATCH 43/46] add init scaling tests --- ott/core/initializers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 9aa0a3cd0..d1bd834ca 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -156,9 +156,8 @@ def __init__( 
self.tolerance = tolerance self.max_iter = max_iter self.vectorized_update = vectorized_update - self.update_fn = lambda f, mod_cost: jax.lax.cond( - vectorized_update, _vectorized_update, _coordinate_update, f, mod_cost - ) + self.update_fn = _vectorized_update if self.vectorized_update else _coordinate_update + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: return ([], { From 196de5f3d1eebf1cec0179756b177f4ab40698f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CJTT94=E2=80=9D?= <“jtthornton1994@gmail.com”> Date: Wed, 17 Aug 2022 19:23:31 +0200 Subject: [PATCH 44/46] simplify vector update flag in sorting initializer --- ott/core/initializers.py | 1 - tests/core/initializers_test.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index d1bd834ca..cd58859de 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -157,7 +157,6 @@ def __init__( self.max_iter = max_iter self.vectorized_update = vectorized_update self.update_fn = _vectorized_update if self.vectorized_update else _coordinate_update - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: return ([], { diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py index 5f1f7f042..872be5325 100644 --- a/tests/core/initializers_test.py +++ b/tests/core/initializers_test.py @@ -79,7 +79,7 @@ def create_ot_problem(rng, n, m, d, epsilon=0.01, online=False): # define sinkhorn functions -@partial(jax.jit, static_argnames=['lse_mode']) +@partial(jax.jit, static_argnames=['lse_mode', 'vector_min']) def run_sinkhorn_sort_init( x, y, a=None, b=None, epsilon=0.01, vector_min=True, lse_mode=True ): From 3dcb736b6b454e368225b842f10dd04f973b5fd5 Mon Sep 17 00:00:00 2001 From: Michal Klein Date: Thu, 18 Aug 2022 00:00:32 +0200 Subject: [PATCH 45/46] Fix documentation rendering --- docs/core.rst | 2 +- ott/core/initializers.py | 23 ++++++++++++----------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/docs/core.rst b/docs/core.rst index 35592f1e6..2acf61183 100644 --- a/docs/core.rst +++ b/docs/core.rst @@ -39,7 +39,7 @@ Sinkhorn Dual Initializers .. autosummary:: :toctree: _autosummary - initializers.SinkhornInitializer + initializers.DefaultInitializer initializers.GaussianInitializer initializers.SortingInitializer diff --git a/ott/core/initializers.py b/ott/core/initializers.py index cd58859de..4afee9400 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -24,9 +24,10 @@ @jax.tree_util.register_pytree_node_class class SinkhornInitializer(ABC): + """Base class for Sinkhorn initializers.""" def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: - return ([], {}) + return [], {} @classmethod def tree_unflatten( @@ -91,8 +92,8 @@ class GaussianInitializer(DefaultInitializer): From :cite:`thornton2022rethinking:22`. Compute Gaussian approximations of each pointcloud, then compute closed from Kantorovich potential betwen Gaussian approximations using Brenier's theorem - (adapt convex/Brenier potential to Kantorovich). Use this Gaussian potential to - initialize Sinkhorn potentials/scalings. + (adapt convex/Brenier potential to Kantorovich). Use this Gaussian potential + to initialize Sinkhorn potentials/scalings. 
""" @@ -138,7 +139,7 @@ class SortingInitializer(DefaultInitializer): DualSort algorithm from :cite:`thornton2022rethinking:22`, solve non-regularized OT problem via sorting, then compute potential through iterated minimum on C-transform and use this potential to initialize - regularized potential + regularized potential. Args: vectorized_update: Use vectorized inner loop if true. @@ -165,15 +166,15 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: 'vectorized_update': self.vectorized_update }) - def init_sorting_dual( + def _init_sorting_dual( self, modified_cost: jnp.ndarray, init_f: jnp.ndarray ) -> jnp.ndarray: """Run DualSort algorithm. Args: - modified_cost: cost matrix minus diagonal column-wise. + modified_cost: cost matrix minus diagonal column-wise. init_f: potential f, array of size n. This is the starting potential, - which is then updated to make the init potential, so an init of an init. + which is then updated to make the init potential, so an init of an init. Returns: potential f, array of size n. @@ -212,24 +213,24 @@ def init_dual_a( ot_problem: OT problem. lse_mode: Return potential if true, scaling if false. init_f: potential f, array of size n. This is the starting potential, - which is then updated to make the init potential, so an init of an init. + which is then updated to make the init potential, so an init of an init. Returns: potential/scaling f_u, array of size n. """ - assert not ot_problem.geom.is_online, "Sorting initializer does not work for online geom" + assert not ot_problem.geom.is_online, "Sorting initializer does not work for online geometry." # check for sorted x, y requires pointcloud and could slow initializer cost_matrix = ot_problem.geom.cost_matrix assert cost_matrix.shape[0] == cost_matrix.shape[ - 1], "Requires square cost matrix" + 1], "Requires square cost matrix." modified_cost = cost_matrix - jnp.diag(cost_matrix)[None, :] n = cost_matrix.shape[0] init_f = jnp.zeros(n) if init_f is None else init_f - f_potential = self.init_sorting_dual(modified_cost, init_f) + f_potential = self._init_sorting_dual(modified_cost, init_f) f_potential = f_potential - jnp.mean(f_potential) f_u = f_potential if lse_mode else ot_problem.geom.scaling_from_potential( From 69f30504d541b4f4c1f969b9bb4667518c6f1138 Mon Sep 17 00:00:00 2001 From: Michal Klein Date: Thu, 18 Aug 2022 00:18:59 +0200 Subject: [PATCH 46/46] [ci skip] Fix typo in docs, use fixture in tests --- ott/core/initializers.py | 17 ++++----- tests/core/initializers_test.py | 66 ++++++++++++++------------------- 2 files changed, 35 insertions(+), 48 deletions(-) diff --git a/ott/core/initializers.py b/ott/core/initializers.py index 4afee9400..05e700a3e 100644 --- a/ott/core/initializers.py +++ b/ott/core/initializers.py @@ -50,7 +50,7 @@ def init_dual_b( @jax.tree_util.register_pytree_node_class class DefaultInitializer(SinkhornInitializer): - """Default Initialization of Sinkhorn dual potentials/primal scalings.""" + """Default initialization of Sinkhorn dual potentials/primal scalings.""" def init_dual_a( self, ot_problem: linear_problems.LinearProblem, lse_mode: bool @@ -87,14 +87,13 @@ def init_dual_b( @jax.tree_util.register_pytree_node_class class GaussianInitializer(DefaultInitializer): - """GaussianInitializer. + """Gaussian initializer. From :cite:`thornton2022rethinking:22`. 
- Compute Gaussian approximations of each pointcloud, then compute closed from - Kantorovich potential betwen Gaussian approximations using Brenier's theorem + Compute Gaussian approximations of each point cloud, then compute closed from + Kantorovich potential between Gaussian approximations using Brenier's theorem (adapt convex/Brenier potential to Kantorovich). Use this Gaussian potential to initialize Sinkhorn potentials/scalings. - """ def init_dual_a( @@ -116,7 +115,7 @@ def init_dual_a( assert isinstance( ot_problem.geom, pointcloud.PointCloud - ), "Gaussian initializer valid only for PointCloud geom" + ), "Gaussian initializer valid only for point clouds." x, y = ot_problem.geom.x, ot_problem.geom.y a, b = ot_problem.a, ot_problem.b @@ -134,7 +133,7 @@ def init_dual_a( @jax.tree_util.register_pytree_node_class class SortingInitializer(DefaultInitializer): - """Sorting Init class. + """Sorting initializer. DualSort algorithm from :cite:`thornton2022rethinking:22`, solve non-regularized OT problem via sorting, then compute potential through @@ -207,7 +206,7 @@ def init_dual_a( lse_mode: bool, init_f: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: - """Apply DualSort algo. + """Apply DualSort algorithm. Args: ot_problem: OT problem. @@ -219,7 +218,7 @@ def init_dual_a( potential/scaling f_u, array of size n. """ assert not ot_problem.geom.is_online, "Sorting initializer does not work for online geometry." - # check for sorted x, y requires pointcloud and could slow initializer + # check for sorted x, y requires point cloud and could slow initializer cost_matrix = ot_problem.geom.cost_matrix assert cost_matrix.shape[0] == cost_matrix.shape[ diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py index 872be5325..340edc7b8 100644 --- a/tests/core/initializers_test.py +++ b/tests/core/initializers_test.py @@ -11,7 +11,7 @@ # limitations under the License. 
# Lint as: python3 -"""Tests for the Gromov Wasserstein.""" +"""Tests for Sinkhorn initializers.""" from functools import partial @@ -27,7 +27,7 @@ def create_sorting_problem(rng, n, epsilon=0.01, online=False): - # definte ot problem + # define ot problem x_init = jnp.array([-1., 0, .22]) y_init = jnp.array([0., 0, 1.1]) x_rng, y_rng = jax.random.split(rng) @@ -35,13 +35,13 @@ def create_sorting_problem(rng, n, epsilon=0.01, online=False): x = jnp.concatenate([x_init, 10 + jnp.abs(jax.random.normal(x_rng, (n,)))]) y = jnp.concatenate([y_init, 10 + jnp.abs(jax.random.normal(y_rng, (n,)))]) - x = np.sort(x) - y = np.sort(y) + x = jnp.sort(x) + y = jnp.sort(y) n = len(x) m = len(y) - a = np.ones(n) / n - b = np.ones(m) / m + a = jnp.ones(n) / n + b = jnp.ones(m) / m batch_size = 3 if online else None geom = pointcloud.PointCloud( @@ -56,23 +56,20 @@ def create_sorting_problem(rng, n, epsilon=0.01, online=False): def create_ot_problem(rng, n, m, d, epsilon=0.01, online=False): - # definte ot problem + # define ot problem x_rng, y_rng = jax.random.split(rng) - mu_a = np.array([-1, 1]) * 5 - mu_b = np.array([0, 0]) + mu_a = jnp.array([-1, 1]) * 5 + mu_b = jnp.array([0, 0]) x = jax.random.normal(x_rng, (n, d)) + mu_a y = jax.random.normal(y_rng, (m, d)) + mu_b - a = np.ones(n) / n - b = np.ones(m) / m + a = jnp.ones(n) / n + b = jnp.ones(m) / m - x_jnp, y_jnp = jnp.array(x), jnp.array(y) batch_size = 3 if online else None - geom = pointcloud.PointCloud( - x_jnp, y_jnp, epsilon=epsilon, batch_size=batch_size - ) + geom = pointcloud.PointCloud(x, y, epsilon=epsilon, batch_size=batch_size) ot_problem = linear_problems.LinearProblem(geom=geom, a=a, b=b) return ot_problem @@ -117,9 +114,9 @@ def run_sinkhorn_gaus_init(x, y, a=None, b=None, epsilon=0.01, lse_mode=True): return out +@pytest.mark.fast class TestInitializers: - @pytest.mark.fast def test_init_pytree(self): @jax.jit @@ -135,14 +132,14 @@ def init_gaus(): init_gaus() init_sort() - @pytest.mark.fast.with_args( + @pytest.mark.parametrize( "vector_min, lse_mode", [(True, True), (True, False), (False, True)] ) - def test_sorting_init(self, vector_min, lse_mode): + def test_sorting_init(self, vector_min: bool, lse_mode: bool): """Tests sorting dual initializer.""" + rng = jax.random.PRNGKey(42) n = 500 epsilon = 0.01 - rng = jax.random.PRNGKey(42) ot_problem = create_sorting_problem( rng=rng, n=n, epsilon=epsilon, online=False @@ -172,40 +169,34 @@ def test_sorting_init(self, vector_min, lse_mode): if lse_mode: assert base_num_iter >= sort_num_iter - @pytest.mark.fast - def test_sorting_init_online(self): + def test_sorting_init_online(self, rng: jnp.ndarray): n = 100 epsilon = 0.01 - rng = jax.random.PRNGKey(0) ot_problem = create_sorting_problem( rng=rng, n=n, epsilon=epsilon, online=True ) sort_init = init_lib.SortingInitializer(vectorized_update=True) - with pytest.raises(AssertionError): + with pytest.raises(AssertionError, match=r"online"): sort_init.init_dual_a(ot_problem=ot_problem, lse_mode=True) - @pytest.mark.fast - def test_sorting_init_square_cost(self): + def test_sorting_init_square_cost(self, rng: jnp.ndarray): n = 100 m = 150 d = 1 epsilon = 0.01 - rng = jax.random.PRNGKey(0) ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) sort_init = init_lib.SortingInitializer(vectorized_update=True) - with pytest.raises(AssertionError): + with pytest.raises(AssertionError, match=r"square"): sort_init.init_dual_a(ot_problem=ot_problem, lse_mode=True) - @pytest.mark.fast - def test_default_initializer(self): + def 
test_default_initializer(self, rng: jnp.ndarray): """Tests default initializer""" n = 200 m = 200 d = 2 epsilon = 0.01 - rng = jax.random.PRNGKey(0) ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) @@ -217,16 +208,14 @@ def test_default_initializer(self): ) # check default is 0 - np.testing.assert_array_equal(jnp.zeros(n), default_potential_a) - np.testing.assert_array_equal(jnp.zeros(m), default_potential_b) + np.testing.assert_array_equal(0., default_potential_a) + np.testing.assert_array_equal(0., default_potential_b) - @pytest.mark.fast - def test_gaus_pointcloud_geom(self): + def test_gauss_pointcloud_geom(self, rng: jnp.ndarray): n = 200 m = 200 d = 2 epsilon = 0.01 - rng = jax.random.PRNGKey(0) ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) @@ -238,18 +227,17 @@ def test_gaus_pointcloud_geom(self): geom=new_geom, a=ot_problem.a, b=ot_problem.b ) - with pytest.raises(AssertionError): + with pytest.raises(AssertionError, match=r"point cloud"): gaus_init.init_dual_a(ot_problem=ot_problem, lse_mode=True) - @pytest.mark.fast.with_args('lse_mode', [True, False]) - def test_gaus_initializer(self, lse_mode): + @pytest.mark.parametrize('lse_mode', [True, False]) + def test_gauss_initializer(self, lse_mode, rng: jnp.ndarray): """Tests Gaussian initializer""" # definte ot problem n = 200 m = 200 d = 2 epsilon = 0.01 - rng = jax.random.PRNGKey(0) ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False)
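
For reference, a short end-to-end sketch of the initializer API introduced in this series. The potential_initializer keyword and the two constructors follow the tests above; the exact import paths and the public sinkhorn signature are assumptions, so treat this as an illustration rather than a verbatim API reference:

import jax
import jax.numpy as jnp

from ott.core import initializers as init_lib
from ott.core.sinkhorn import sinkhorn
from ott.geometry import pointcloud

rng = jax.random.PRNGKey(0)
x_rng, y_rng = jax.random.split(rng)
n, m, d, epsilon = 200, 200, 2, 1e-2
# Two Gaussian blobs with shifted means, as in create_ot_problem above.
x = jax.random.normal(x_rng, (n, d)) + jnp.array([-5., 5.])
y = jax.random.normal(y_rng, (m, d))
geom = pointcloud.PointCloud(x, y, epsilon=epsilon)

# Default initialization: zero potentials (or unit scalings).
out_default = sinkhorn(geom)

# Gaussian initialization: fit a Gaussian to each point cloud and start
# Sinkhorn from the closed-form Brenier potential between the two fits.
out_gaus = sinkhorn(geom, potential_initializer=init_lib.GaussianInitializer())

# DualSort initialization: needs a square cost matrix and a non-online geometry.
out_sort = sinkhorn(
    geom, potential_initializer=init_lib.SortingInitializer(vectorized_update=True)
)

# A good initializer should not increase the recorded iteration count.
print(jnp.sum(out_default.errors > -1), jnp.sum(out_gaus.errors > -1))
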