Skip to content

Commit

Permalink
Refactor GW initialization (#133)
Browse files Browse the repository at this point in the history
* Refactor GW initialization

* Fix rank being traced

* Fix not re-using init from previous iters in GW

* Fix tests

* Simplify LRGW iniitalization

* Regenerate GWLR notebook

* Fix old code in test scaling
  • Loading branch information
michalk8 authored Sep 5, 2022
1 parent 1419513 commit 19c537e
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 65 deletions.
34 changes: 24 additions & 10 deletions docs/notebooks/GWLRSinkhorn.ipynb

Large diffs are not rendered by default.

54 changes: 32 additions & 22 deletions ott/core/gromov_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

# Lint as: python3
"""A Jax version of the regularised GW Solver (Peyre et al. 2016)."""
import functools
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

import jax
Expand Down Expand Up @@ -101,13 +100,12 @@ class GWState(NamedTuple):
old_transport_mass: Intermediary value of the mass of the transport matrix.
"""

costs: Optional[jnp.ndarray] = None
linear_convergence: Optional[jnp.ndarray] = None
costs: jnp.ndarray
linear_convergence: jnp.ndarray
linear_state: LinearOutput
linear_pb: linear_problems.LinearProblem
old_transport_mass: float
errors: Optional[jnp.ndarray] = None
linear_state: Optional[LinearOutput] = None
linear_pb: Optional[linear_problems.LinearProblem] = None
# Intermediate values.
old_transport_mass: float = 1.0

def set(self, **kwargs: Any) -> 'GWState':
"""Return a copy of self, possibly with overwrites."""
Expand All @@ -125,6 +123,7 @@ def update(
linear_convergence = self.linear_convergence.at[iteration].set(
linear_sol.converged
)

return self.set(
linear_state=linear_sol,
linear_pb=linear_pb,
Expand All @@ -146,8 +145,7 @@ def __call__(self, prob: quad_problems.QuadraticProblem) -> GWOutput:
# Possibly jit iteration functions and run. Closure on rank to
# avoid jitting issues, since rank value will be used to branch between
# a default entropic GW or a low-rank GW.
iterations_fn = functools.partial(iterations, rank=self.rank)
gromov_fn = jax.jit(iterations_fn) if self.jit else iterations_fn
gromov_fn = jax.jit(iterations) if self.jit else iterations
out = gromov_fn(self, prob)
# TODO(lpapaxanthos): remove stop_gradient when using backprop
if self.is_low_rank:
Expand All @@ -167,24 +165,31 @@ def __call__(self, prob: quad_problems.QuadraticProblem) -> GWOutput:
return out.set(linear_state=linear_state, convergence=convergence)

def init_state(
self, prob: quad_problems.QuadraticProblem, rank: int
self,
prob: quad_problems.QuadraticProblem,
) -> GWState:
"""Initialize the state of the Gromov-Wasserstein iterations."""
if rank > 0:
linearization = prob.init_lr_linearization(rank)
if self.is_low_rank:
linear_prob = prob.init_lr_linearization(self.linear_ot_solver)
else:
linearization = prob.init_linearization(self.epsilon)
linear_prob = prob.init_linearization(self.epsilon)

linear_state = self.linear_ot_solver(linearization)
linear_state = self.linear_ot_solver(linear_prob)
num_iter = self.max_iterations
transport_mass = prob.init_transport_mass()

if self.store_inner_errors:
errors = -jnp.ones((num_iter, self.linear_ot_solver.outer_iterations))
else:
errors = None

return GWState(
-jnp.ones((num_iter,)), -jnp.ones((num_iter,)), errors, linear_state,
linearization, transport_mass
costs=-jnp.ones((num_iter,)),
linear_convergence=-jnp.ones((num_iter,)),
linear_state=linear_state,
linear_pb=linear_prob,
old_transport_mass=transport_mass,
errors=errors
)

def output_from_state(self, state: GWState) -> GWOutput:
Expand All @@ -208,7 +213,8 @@ def output_from_state(self, state: GWState) -> GWOutput:


def iterations(
solver: GromovWasserstein, prob: quad_problems.QuadraticProblem, rank: int
solver: GromovWasserstein,
prob: quad_problems.QuadraticProblem,
) -> GWOutput:
"""Jittable Gromov-Wasserstein outer loop."""

Expand All @@ -219,19 +225,21 @@ def cond_fn(
return solver._continue(state, iteration)

def body_fn(
iteration: int, constants: GromovWasserstein, state: GWState,
iteration: int, solver: GromovWasserstein, state: GWState,
compute_error: bool
) -> GWState:
del compute_error # Always assumed True for outer loop of GW.
solver = constants
if rank > 0:

if solver.is_low_rank:
init = state.linear_state.q, state.linear_state.r, state.linear_state.g
linear_pb = prob.update_lr_linearization(state.linear_state)
else:
init = state.linear_state.f, state.linear_state.g
linear_pb = prob.update_linearization(
state.linear_state, solver.epsilon, state.old_transport_mass
)

out = solver.linear_ot_solver(linear_pb)
out = solver.linear_ot_solver(linear_pb, init=init)
old_transport_mass = jax.lax.stop_gradient(
state.linear_state.transport_mass()
)
Expand All @@ -246,7 +254,7 @@ def body_fn(
max_iterations=solver.max_iterations,
inner_iterations=1,
constants=solver,
state=solver.init_state(prob, rank)
state=solver.init_state(prob)
)

return solver.output_from_state(state)
Expand Down Expand Up @@ -300,6 +308,8 @@ def make(
sink = sinkhorn_lr.make(
rank=rank, epsilon=epsilon, **linear_ot_solver_kwargs
)
else:
raise ValueError(f"Invalid value for `rank={rank}`.")

return GromovWasserstein(
epsilon,
Expand Down
6 changes: 3 additions & 3 deletions ott/core/gw_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,17 +241,17 @@ def output_from_state(self, state: GWBarycenterState) -> GWBarycenterState:

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
children, aux = super().tree_flatten()
aux["quad_solver"] = self._quad_solver
return children, aux
return children + [self._quad_solver], aux

@classmethod
def tree_unflatten(
cls, aux_data: Dict[str, Any], children: Sequence[Any]
) -> "GromovWassersteinBarycenter":
epsilon, _, _, threshold = children
epsilon, _, threshold, quad_solver = children
return cls(
epsilon=epsilon,
threshold=threshold,
quad_solver=quad_solver,
**aux_data,
)

Expand Down
48 changes: 31 additions & 17 deletions ott/core/quad_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,18 @@ def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss:

@jax.tree_util.register_pytree_node_class
class QuadraticProblem:
"""Definition of the quadratic regularized OT problem.
r"""Definition of the quadratic regularized OT problem.
The quadratic loss of a single OT matrix is assumed to
have the form given in :cite:`peyre:16`, eq. 4.
The two geometries below parameterize matrices C and bar{C} in that equation.
The function L (of two real values) in that equation is assumed
to match the form given in Eq. 5., with our notations:
The two geometries below parameterize matrices :math:`C` and :math:`\bar{C}`
in that equation. The function :math:`L` (of two real values) in that equation
is assumed to match the form given in eq. 5., with our notations:
L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y)
.. math::
L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y)
Args:
geom_xx: the geometry.Geometry object defining the ground geometry / cost
Expand Down Expand Up @@ -175,10 +177,12 @@ def __init__(

@property
def is_fused(self) -> bool:
"""Whether the problem is fused."""
return self.geom_xy is not None

@property
def is_low_rank(self) -> bool:
"""Whether all geometries are low-rank."""
return (
isinstance(self.geom_xx, low_rank.LRCGeometry) and
isinstance(self.geom_yy, low_rank.LRCGeometry) and (
Expand All @@ -189,14 +193,17 @@ def is_low_rank(self) -> bool:

@property
def linear_loss(self) -> Tuple[Loss, Loss]:
"""Linear part of the GW loss."""
return self.loss.f1, self.loss.f2

@property
def quad_loss(self) -> Tuple[Loss, Loss]:
"""Quadratic part of the GW loss."""
return self.loss.h1, self.loss.h2

@property
def is_balanced(self) -> bool:
"""Whether the problem is balanced."""
return ((not self.gw_unbalanced_correction) or
(self.tau_a == 1.0 and self.tau_b == 1.0))

Expand All @@ -219,11 +226,13 @@ def tree_unflatten(cls, aux_data, children):

@property
def a(self) -> jnp.ndarray:
"""Source marginals."""
num_a = self.geom_xx.shape[0]
return jnp.ones((num_a,)) / num_a if self._a is None else self._a

@property
def b(self) -> jnp.ndarray:
"""Target marginals."""
num_b = self.geom_yy.shape[0]
return jnp.ones((num_b,)) / num_b if self._b is None else self._b

Expand Down Expand Up @@ -416,24 +425,29 @@ def init_linearization(
)

def init_lr_linearization(
self, rank: int, **kwargs: Any
self,
solver: sinkhorn_lr.LRSinkhorn,
**kwargs: Any,
) -> linear_problems.LinearProblem:
"""Linearizes a Quad problem with a predefined initializer."""
x_ = self.geom_xx.apply_square_cost(self.a)
y_ = self.geom_yy.apply_square_cost(self.b)
geom_ = pointcloud.PointCloud(x_, y_).to_LRCGeometry()
out = sinkhorn_lr.LRSinkhorn(
rank=rank, **kwargs
)(
linear_problems.LinearProblem(geom_, self.a, self.b)
"""Linearize a Quad problem with a predefined initializer."""
x = self.geom_xx.apply_square_cost(self.a)
y = self.geom_yy.apply_square_cost(self.b)
geom = pointcloud.PointCloud(x, y).to_LRCGeometry()

prob = linear_problems.LinearProblem(geom, self.a, self.b)
q, r, g = solver.initializer(prob, **kwargs)
dummy_out = sinkhorn_lr.LRSinkhornOutput(
q=q, r=r, g=g, costs=None, criterions=None, ot_prob=prob
)
return linear_problems.LinearProblem(
self.update_lr_geom(out),

prob = linear_problems.LinearProblem(
self.update_lr_geom(dummy_out),
self.a,
self.b,
tau_a=self.tau_a,
tau_b=self.tau_b
)
return prob

def update_lr_geom(
self, lr_sink: sinkhorn_lr.LRSinkhornOutput
Expand Down Expand Up @@ -554,7 +568,7 @@ def convertible(geom: geometry.Geometry) -> bool:

geom_xx, geom_yy, geom_xy = self.geom_xx, self.geom_yy, self.geom_xy
# either explicitly via cost factorization or implicitly (e.g., a PC)
return self.ranks != 1 or (
return self.ranks != -1 or (
convertible(geom_xx) and convertible(geom_yy) and
(geom_xy is None or convertible(geom_xy))
)
Expand Down
6 changes: 3 additions & 3 deletions ott/core/was_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,21 @@ def is_low_rank(self) -> bool:
return self.rank > 0

def tree_flatten(self):
return ([self.epsilon, self.rank, self.linear_ot_solver, self.threshold],
return ([self.epsilon, self.linear_ot_solver, self.threshold],
dict(
min_iterations=self.min_iterations,
max_iterations=self.max_iterations,
jit=self.jit,
rank=self.rank,
store_inner_errors=self.store_inner_errors,
**self._kwargs
))

@classmethod
def tree_unflatten(cls, aux_data, children):
epsilon, rank, linear_ot_solver, threshold = children
epsilon, linear_ot_solver, threshold = children
return cls(
epsilon=epsilon,
rank=rank,
linear_ot_solver=linear_ot_solver,
threshold=threshold,
**aux_data
Expand Down
4 changes: 1 addition & 3 deletions ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,7 @@ def inv_scale_cost(self) -> float:
return 1.0 / jnp.nanmedian(self._cost_matrix)
raise ValueError(f'Scaling {self._scale_cost} not implemented.')

def _set_scale_cost(
self, scale_cost: Optional[Union[bool, float, str]]
) -> "Geometry":
def _set_scale_cost(self, scale_cost: Union[bool, float, str]) -> "Geometry":
# case when `geom` doesn't have `scale_cost` or doesn't need to be modified
# `False` retains the original scale
if scale_cost is False or scale_cost == self._scale_cost:
Expand Down
10 changes: 5 additions & 5 deletions tests/core/gromov_wasserstein_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,18 +358,18 @@ def test_gw_lr_matches_fused(self, rng: jnp.ndarray):
ot_gw = solver(prob)

# Test solutions look alike
assert 0.1 > jnp.linalg.norm(ot_gwlr.matrix - ot_gw.matrix)
assert 0.13 > jnp.linalg.norm(ot_gwlr.matrix - ot_gwlreps.matrix)
assert 0.11 > jnp.linalg.norm(ot_gwlr.matrix - ot_gw.matrix)
assert 0.15 > jnp.linalg.norm(ot_gwlr.matrix - ot_gwlreps.matrix)
# Test at least some difference when adding bigger entropic regularization
assert jnp.linalg.norm(ot_gwlr.matrix - ot_gwlreps.matrix) > 1e-3

@pytest.mark.parametrize("scale_cost", [True, "mean", "max_cost"])
def test_gw_fused_scale_cost(self, scale_cost: Union[bool, str]):
epsilon = 0.1
fused_penalty = 1
geom_x = pointcloud.PointCloud(self.x, scale_cost=None)
geom_y = pointcloud.PointCloud(self.y, scale_cost=None)
geom_xy = pointcloud.PointCloud(self.xx, self.yy, scale_cost=None)
geom_x = pointcloud.PointCloud(self.x, scale_cost=1.)
geom_y = pointcloud.PointCloud(self.y, scale_cost=1.)
geom_xy = pointcloud.PointCloud(self.xx, self.yy, scale_cost=1.)
geom_x_scaled = pointcloud.PointCloud(self.x, scale_cost=scale_cost)
geom_y_scaled = pointcloud.PointCloud(self.y, scale_cost=scale_cost)
geom_xy_scaled = pointcloud.PointCloud(
Expand Down
4 changes: 2 additions & 2 deletions tests/core/sinkhorn_diff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ class TestSinkhornHessian:
tau_b=[1.0, .91],
shape=[(12, 15)],
arg=[0, 1],
only_fast=[-1]
only_fast=-1
)
def test_hessian_sinkhorn(
self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float,
Expand Down Expand Up @@ -764,7 +764,7 @@ def loss(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True):
lse_mode=lse_mode,
threshold=1e-4,
use_danskin=False,
implicit_diff=implicit_diff
implicit_diff=implicit_diff,
)
return solver(prob).reg_ot_cost

Expand Down

0 comments on commit 19c537e

Please sign in to comment.