
Feature/ Introduce initialization methods for Sinkhorn #98

Merged Aug 17, 2022 (49 commits; changes shown below are from 38 commits)

Commits
83f8996
add sorting, gaus initializers, add gaus helpers to tools
Jun 30, 2022
3549517
add initialization logic to sinkhorn
Jul 1, 2022
d8cdfd3
remove general ot problem type
Jul 1, 2022
9050efb
remove import tools.gaussian from top level
Jul 1, 2022
a680d6b
remove problems from top level
Jul 1, 2022
378777e
do not register initializer as pytree
Jul 1, 2022
7fa567a
add initializer to make
Jul 1, 2022
4ce8357
rename init arg to ot_problem
Jul 1, 2022
2856b4b
rename init arg to ot_problem
Jul 1, 2022
42eb327
scale gaus init by 2
Jul 1, 2022
2de12cc
typo
Jul 1, 2022
4d80508
add basic speed tests
Jul 1, 2022
23a03e9
add init to transport tools wrapper, tidy docstring
Jul 1, 2022
512150f
ceneter potentials in initializers
Jul 1, 2022
c90cd40
fix lse for null weights
Jul 3, 2022
4f5fbd6
fix flake8 and accidental removal
Jul 3, 2022
9b0f224
tidy docstrings
Jul 3, 2022
d33d89c
tidy docstrings
Jul 3, 2022
d83f913
docstring flake8
Jul 3, 2022
30455a7
flake 8 formatting
Jul 4, 2022
de2f6c4
Merge branch 'ott-jax:main' into main
JTT94 Jul 4, 2022
17a8db9
fix typo
Jul 4, 2022
92924bf
fix stop gradient in Gaussian to include weights and x,y
Jul 4, 2022
077219b
fix stop gradient in Gaussian to include weights and x,y
Jul 4, 2022
e71f6b4
fix docstring spaces
Jul 4, 2022
4c1f0b3
feedback from initial review
Jul 5, 2022
99d0bd1
re order local functions before state init
Jul 5, 2022
60b973b
optional init_f in sorting init
Jul 6, 2022
3e2df88
docstring insert line before return
Jul 6, 2022
14f3b64
lint fix
Jul 6, 2022
5d1f648
incorporate feedback in commit
Jul 12, 2022
aca73e7
tidy tests, use jax.lax.cond for logic instead of if
Jul 13, 2022
86b32f5
add docs, rename sorting initializer
Jul 13, 2022
c12fea8
fix merge conflict
Jul 13, 2022
1c56799
resolve test errors in sinkhorn test
Jul 13, 2022
17e2524
fetch upstream for merge
Jul 13, 2022
161a67a
incorporate feedback, update tests to pytest, change docstrings, intr…
Jul 14, 2022
f6fdd5c
fix docstring spaces
Jul 14, 2022
c189f18
remove spaces and add bibtex
Jul 16, 2022
3855bb9
add errors for non square cost matrix for sorting, online geoms for i…
Aug 17, 2022
4e9c46c
Merge branch 'main' into main
JTT94 Aug 17, 2022
49e5b4f
merge fix lint
Aug 17, 2022
ce4d14c
merge fix lint
Aug 17, 2022
1f93053
add initializers as pytees
Aug 17, 2022
f665911
add init scaling tests
Aug 17, 2022
f4d4c1e
add init scaling tests
Aug 17, 2022
196de5f
simplify vector update flag in sorting initializer
Aug 17, 2022
3dcb736
Fix documentation rendering
michalk8 Aug 17, 2022
69f3050
[ci skip] Fix typo in docs, use fixture in tests
michalk8 Aug 17, 2022
9 changes: 9 additions & 0 deletions docs/core.rst
@@ -31,6 +31,15 @@ Sinkhorn
sinkhorn.Sinkhorn
sinkhorn.SinkhornOutput

Sinkhorn Dual Initializers
--------------------------
.. autosummary::
:toctree: _autosummary

initializers.SinkhornInitializer
initializers.GaussianInitializer
initializers.SortingInitializer

Low-Rank Sinkhorn
-----------------
.. autosummary::
1 change: 1 addition & 0 deletions ott/core/__init__.py
@@ -22,6 +22,7 @@
discrete_barycenter,
gromov_wasserstein,
implicit_differentiation,
initializers,
linear_problems,
momentum,
quad_problems,
258 changes: 258 additions & 0 deletions ott/core/initializers.py
@@ -0,0 +1,258 @@
# Copyright 2022 The OTT Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sinkhorn initializers."""
from typing import Optional

import jax
import jax.numpy as jnp

from ott.core import linear_problems
from ott.geometry import pointcloud


class SinkhornInitializer:
"""Base class for Sinkhorn dual initializers."""

def init_dual_a(
self, ot_problem: linear_problems.LinearProblem, lse_mode: bool
) -> jnp.ndarray:
"""Initialization for Sinkhorn potential/scaling f_u."""

def init_dual_b(
self, ot_problem: linear_problems.LinearProblem, lse_mode: bool
) -> jnp.ndarray:
"""Initialization for Sinkhorn potential/scaling g_v."""


class DefaultInitializer(SinkhornInitializer):
"""Default Initialization of Sinkhorn dual potentials/ primal scalings."""

def init_dual_a(
self, ot_problem: linear_problems.LinearProblem, lse_mode: bool
) -> jnp.ndarray:
"""Initialization for Sinkhorn potential/ scaling f_u.

Args:
ot_problem: OT problem between discrete distributions of size n and m.
lse_mode: Return potential if true, scaling if false.

Returns:
potential/ scaling, array of size n
"""
a = ot_problem.a
init_dual_a = jnp.zeros_like(a) if lse_mode else jnp.ones_like(a)
return init_dual_a

def init_dual_b(
self, ot_problem: linear_problems.LinearProblem, lse_mode: bool
) -> jnp.ndarray:
"""Initialization for Sinkhorn potential/ scaling g_v.

Args:
ot_problem: OT problem between discrete distributions of size n and m.
lse_mode: Return potential if true, scaling if false.

Returns:
potential/ scaling, array of size m
"""
b = ot_problem.b
init_dual_b = jnp.zeros_like(b) if lse_mode else jnp.ones_like(b)
return init_dual_b
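
A quick illustrative sketch (not part of this diff): the two modes describe the same dual point, since the zero potential in `lse_mode` corresponds to the unit scaling `exp(0 / epsilon)` in kernel mode. Shapes and `epsilon` below are arbitrary.

```python
import jax.numpy as jnp

from ott.core import initializers, linear_problems
from ott.geometry import pointcloud

x, y = jnp.zeros((4, 2)), jnp.ones((3, 2))
prob = linear_problems.LinearProblem(
    pointcloud.PointCloud(x, y, epsilon=1e-1)
)

init = initializers.DefaultInitializer()
f = init.init_dual_a(prob, lse_mode=True)   # zero potential, size n = 4
u = init.init_dual_a(prob, lse_mode=False)  # unit scaling, size n = 4
# Same dual point expressed in the two modes: u = exp(f / epsilon).
assert jnp.allclose(u, jnp.exp(f / 1e-1))
```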


class GaussianInitializer(DefaultInitializer):
"""GaussianInitializer.

From https://arxiv.org/abs/2206.07630.
JTT94 marked this conversation as resolved.
Show resolved Hide resolved
Compute Gaussian approximations of each pointcloud, then compute closed from
Kantorovich potential betwen Gaussian approximations using Brenier's theorem
(adapt convex/ Brenier potential to Kantorovich). Use this Gaussian potential to
initialize Sinkhorn potentials/ scalings.

"""

def __init__(self):
super().__init__()

def init_dual_a(
self,
ot_problem: linear_problems.LinearProblem,
lse_mode: bool,
) -> jnp.ndarray:
"""Gaussian init function.

Args:
ot_problem: OT problem description with geometry and weights.
lse_mode: Return potential if true, scaling if false.

Returns:
potential/scaling f_u, array of size n.
"""
# import Gaussian here due to circular imports
from ott.tools.gaussian_mixture import gaussian

if not isinstance(ot_problem.geom, pointcloud.PointCloud):
# Gaussian init needs point-cloud geometry; fall back to default init.
return super().init_dual_a(ot_problem, lse_mode)
else:

x, y = ot_problem.geom.x, ot_problem.geom.y
a, b = ot_problem.a, ot_problem.b

gaussian_a = gaussian.Gaussian.from_samples(x, weights=a)
gaussian_b = gaussian.Gaussian.from_samples(y, weights=b)
# Brenier potential for cost ||x-y||^2/2, multiply by two for ||x-y||^2
f_potential = 2 * gaussian_a.f_potential(dest=gaussian_b, points=x)
f_potential = f_potential - jnp.mean(f_potential)
f_u = f_potential if lse_mode else ot_problem.scaling_from_potential(
f_potential
)
return f_u
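
A minimal usage sketch for the Gaussian initializer (illustrative only; assumes the API introduced in this diff, with arbitrary sample sizes, seed, and epsilon):

```python
import jax
import jax.numpy as jnp

from ott.core import initializers, linear_problems
from ott.geometry import pointcloud

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (128, 2))        # source samples
y = jax.random.normal(key_y, (128, 2)) + 1.0  # shifted target samples

ot_prob = linear_problems.LinearProblem(
    pointcloud.PointCloud(x, y, epsilon=1e-2)
)
f_init = initializers.GaussianInitializer().init_dual_a(
    ot_prob, lse_mode=True
)  # centered potential of size n = 128
```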


class SortingInitializer(DefaultInitializer):
"""Sorting Init class.

DualSort algorithm from https://arxiv.org/abs/2206.07630, solve
non-regularized OT problem via sorting, then compute potential through
iterated minimum on C-transform and use this potential to initialize
regularized potential

Args:
vectorized_update: Use vectorized inner loop if true.
tolerance: DualSort convergence threshold.
max_iter: Max DualSort steps.
"""

def __init__(
self,
vectorized_update: bool = True,
tolerance: float = 1e-2,
max_iter: int = 100
):

super().__init__()

self.tolerance = tolerance
self.max_iter = max_iter
self.update_fn = lambda f, mod_cost: jax.lax.cond(
vectorized_update, _vectorized_update, _coordinate_update, f, mod_cost
)

def init_sorting_dual(
self, modified_cost: jnp.ndarray, init_f: jnp.ndarray
) -> jnp.ndarray:
"""Run DualSort algorithm.

Args:
modified_cost: cost matrix with its diagonal subtracted column-wise,
i.e. modified_cost[i, j] = C[i, j] - C[j, j].
init_f: initial potential f, array of size n; DualSort refines this
starting point to produce the final initialization.

Returns:
potential f, array of size n.
"""

def body_fn(state):
prev_f, _, it = state
new_f = self.update_fn(prev_f, modified_cost)
diff = jnp.sum((new_f - prev_f) ** 2)
it += 1
return new_f, diff, it

def cond_fn(state):
_, diff, it = state
return jnp.logical_and(diff > self.tolerance, it < self.max_iter)

it = 0
diff = self.tolerance + 1.0
state = (init_f, diff, it)

f_potential, _, it = jax.lax.while_loop(
cond_fun=cond_fn, body_fun=body_fn, init_val=state
)

return f_potential

def init_dual_a(
self,
ot_problem: linear_problems.LinearProblem,
lse_mode: bool,
init_f: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
"""Apply DualSort algo.

Args:
ot_problem: OT problem.
lse_mode: Return potential if true, scaling if false.
init_f: initial potential f, array of size n; DualSort refines this
starting point to produce the final initialization.

Returns:
potential/ scaling f_u, array of size n.
"""
if ot_problem.geom.is_online:
# Online geometry has no materialized cost matrix; fall back to default.
return super().init_dual_a(ot_problem, lse_mode)
else:
cost_matrix = ot_problem.geom.cost_matrix
modified_cost = cost_matrix - jnp.diag(cost_matrix)[None, :]

n = cost_matrix.shape[0]
init_f = jnp.zeros(n) if init_f is None else init_f

f_potential = self.init_sorting_dual(modified_cost, init_f)
f_potential = f_potential - jnp.mean(f_potential)

f_u = f_potential if lse_mode else ot_problem.scaling_from_potential(
f_potential
)

return f_u


def _vectorized_update(
f: jnp.ndarray, modified_cost: jnp.ndarray
) -> jnp.ndarray:
"""Inner loop DualSort Update.

Args:
f: potential f, array of size n.
modified_cost: cost matrix with its diagonal subtracted column-wise.

Returns:
updated potential vector, f.
"""
f = jnp.min(modified_cost + f[None, :], axis=1)
return f


def _coordinate_update(
f: jnp.ndarray, modified_cost: jnp.ndarray
) -> jnp.ndarray:
"""Coordinate-wise updates within inner loop.

Args:
f: potential f, array of size n.
modified_cost: cost matrix with its diagonal subtracted column-wise.

Returns:
updated potential vector, f.
"""

def body_fn(i, f):
new_f = jnp.min(modified_cost[i, :] + f)
f = f.at[i].set(new_f)
return f

return jax.lax.fori_loop(0, len(f), body_fn, f)
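
A toy comparison of the two update rules above, with illustrative values and assuming `_vectorized_update` and `_coordinate_update` from this file are in scope: the vectorized rule applies f_i <- min_j (C[i, j] - C[j, j] + f[j]) using the old f everywhere (a Jacobi-style sweep), while the coordinate rule reuses freshly updated coordinates within the sweep (Gauss-Seidel style). Single sweeps can differ, but both share the same fixed point.

```python
import jax.numpy as jnp

C = jnp.array([[0.0, 9.0, 1.0],
               [2.0, 0.0, 9.0],
               [9.0, 9.0, 0.0]])
modified_cost = C - jnp.diag(C)[None, :]  # here diag(C) = 0, so just C
f0 = jnp.array([5.0, 9.0, 0.0])

print(_vectorized_update(f0, modified_cost))   # [1. 7. 0.]
print(_coordinate_update(f0, modified_cost))   # [1. 3. 0.] (reuses new f[0])
```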
24 changes: 19 additions & 5 deletions ott/core/sinkhorn.py
@@ -23,6 +23,7 @@
from ott.core import anderson as anderson_lib
from ott.core import fixed_point_loop
from ott.core import implicit_differentiation as implicit_lib
from ott.core import initializers as init_lib
from ott.core import linear_problems
from ott.core import momentum as momentum_lib
from ott.core import unbalanced_functions
@@ -349,6 +350,8 @@ def __init__(
use_danskin: Optional[bool] = None,
implicit_diff: Optional[implicit_lib.ImplicitDiff
] = implicit_lib.ImplicitDiff(), # noqa: E124
potential_initializer: init_lib.SinkhornInitializer = init_lib
.DefaultInitializer(),
jit: bool = True
):
self.lse_mode = lse_mode
@@ -368,6 +371,7 @@ def __init__(
self.anderson = anderson
self.implicit_diff = implicit_diff
self.parallel_dual_updates = parallel_dual_updates
self.potential_initializer = potential_initializer
self.jit = jit

# Force implicit_differentiation to True when using Anderson acceleration,
@@ -400,18 +404,25 @@ def __call__(
init: Optional[Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]]] = None
) -> SinkhornOutput:
"""Main interface to run sinkhorn.""" # noqa: D401
# initialization
init_dual_a, init_dual_b = (init if init is not None else (None, None))
a, b = ot_prob.a, ot_prob.b

if init_dual_a is None:
init_dual_a = jnp.zeros_like(a) if self.lse_mode else jnp.ones_like(a)
init_dual_a = self.potential_initializer.init_dual_a(
ot_problem=ot_prob, lse_mode=self.lse_mode
)

if init_dual_b is None:
init_dual_b = jnp.zeros_like(b) if self.lse_mode else jnp.ones_like(b)
init_dual_b = self.potential_initializer.init_dual_b(
ot_problem=ot_prob, lse_mode=self.lse_mode
)

# Cancel dual variables for zero weights.
init_dual_a = jnp.where(
a > 0, init_dual_a, -jnp.inf if self.lse_mode else 0.0
ot_prob.a > 0, init_dual_a, -jnp.inf if self.lse_mode else 0.0
)
init_dual_b = jnp.where(
b > 0, init_dual_b, -jnp.inf if self.lse_mode else 0.0
ot_prob.b > 0, init_dual_b, -jnp.inf if self.lse_mode else 0.0
)

run_fn = jax.jit(run) if self.jit else run
@@ -691,6 +702,8 @@ def make(
precondition_fun: Optional[Callable[[float], float]] = None,
parallel_dual_updates: bool = False,
use_danskin: bool = None,
potential_initializer: init_lib.SinkhornInitializer = init_lib
.DefaultInitializer(),
jit: bool = False
) -> Sinkhorn:
"""For backward compatibility."""
@@ -725,6 +738,7 @@
implicit_diff=implicit_diff,
parallel_dual_updates=parallel_dual_updates,
use_danskin=use_danskin,
potential_initializer=potential_initializer,
jit=jit
)
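
Putting the pieces together, a hedged end-to-end sketch of the API added in this PR (sizes, shift, and seed are illustrative; the sorting initializer needs a square cost matrix, hence n = m):

```python
import jax
import jax.numpy as jnp

from ott.core import initializers, linear_problems, sinkhorn
from ott.geometry import pointcloud

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (64, 2))
y = jax.random.normal(key_y, (64, 2)) + 0.5

ot_prob = linear_problems.LinearProblem(pointcloud.PointCloud(x, y))
solver = sinkhorn.Sinkhorn(
    potential_initializer=initializers.SortingInitializer(
        vectorized_update=True
    )
)
out = solver(ot_prob)
print(out.converged, out.reg_ot_cost)
```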
