Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sparse Chebyshev approximation #502

Merged
merged 15 commits into from
Mar 18, 2024
45 changes: 35 additions & 10 deletions src/ott/geometry/geodesic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import jax
import jax.experimental.sparse as jesp
Expand All @@ -22,10 +22,11 @@
from ott import utils
from ott.geometry import geometry
from ott.math import utils as mu
from ott.types import Array_g

__all__ = ["Geodesic"]

Array_g = Union[jnp.ndarray, jesp.BCOO]


@jax.tree_util.register_pytree_node_class
class Geodesic(geometry.Geometry):
Expand Down Expand Up @@ -106,13 +107,10 @@ def from_graph(
if t is None:
t = (jnp.sum(G) / jnp.sum(G > 0.0)) ** 2

degree = jnp.sum(G, axis=1)
laplacian = jnp.diag(degree) - G
if normalize:
inv_sqrt_deg = jnp.diag(
jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0)
)
laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg
if isinstance(G, jesp.BCOO):
laplacian = compute_sparse_laplacian(G, normalize)
else:
laplacian = compute_dense_laplacian(G, normalize)

if eigval is None:
eigval = compute_largest_eigenvalue(laplacian, rng)
Expand Down Expand Up @@ -220,6 +218,33 @@ def tree_unflatten( # noqa: D102
return cls(*children, **aux_data)


def normalize_laplacian(laplacian: Array_g, degree: jnp.ndarray) -> Array_g:
inv_sqrt_deg = jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0)
return inv_sqrt_deg[:, None] * laplacian * inv_sqrt_deg[None, :]


def compute_dense_laplacian(
G: jnp.ndarray, normalize: bool = False
) -> jnp.ndarray:
degree = jnp.sum(G, axis=1)
laplacian = jnp.diag(degree) - G
if normalize:
laplacian = normalize_laplacian(laplacian, degree)
return laplacian


def compute_sparse_laplacian(
G: jesp.BCOO, normalize: bool = False
) -> jesp.BCOO:
n, _ = G.shape
data_degree, ixs = G.sum(1).todense(), jnp.arange(n)
degree = jesp.BCOO((data_degree, jnp.c_[ixs, ixs]), shape=(n, n))
laplacian = degree - G
if normalize:
laplacian = normalize_laplacian(laplacian, data_degree)
return laplacian


def compute_largest_eigenvalue(
laplacian_matrix: jnp.ndarray,
rng: jax.Array,
Expand All @@ -242,7 +267,7 @@ def compute_largest_eigenvalue(


def expm_multiply(
L: jnp.ndarray, X: jnp.ndarray, coeff: jnp.ndarray, eigval: float
L: Array_g, X: jnp.ndarray, coeff: jnp.ndarray, eigval: float
) -> jnp.ndarray:

def body(carry, c):
Expand Down
8 changes: 2 additions & 6 deletions src/ott/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Protocol, Union
from typing import Protocol

import jax.experimental.sparse as jesp
import jax.numpy as jnp

__all__ = ["Transport", "Array_g"]
__all__ = ["Transport"]

# TODO(michalk8): introduce additional types here

# Either a dense or sparse array.
Array_g = Union[jnp.ndarray, jesp.BCOO]


class Transport(Protocol):
"""Interface for the solution of a transport problem.
Expand Down
29 changes: 27 additions & 2 deletions tests/geometry/geodesic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Optional, Union

import jax
import jax.experimental.sparse as jesp
import jax.numpy as jnp
import networkx as nx
import numpy as np
Expand All @@ -29,6 +30,7 @@ def random_graph(
n: int,
p: float = 0.3,
seed: Optional[int] = 0,
is_sparse: bool = False,
*,
return_laplacian: bool = False,
directed: bool = False,
Expand All @@ -45,6 +47,8 @@ def random_graph(
G
) if return_laplacian else nx.linalg.adjacency_matrix(G)

if is_sparse:
return jesp.BCOO.from_scipy_sparse(G)
return jnp.asarray(G.toarray())


Expand Down Expand Up @@ -196,7 +200,8 @@ def laplacian(G: jnp.ndarray) -> jnp.ndarray:
np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)

@pytest.mark.fast.with_args(jit=[False, True], only_fast=0)
def test_geo_sinkhorn(self, rng: jax.Array, jit: bool):
@pytest.mark.parametrize("is_sparse", [True, False])
def test_geo_sinkhorn(self, rng: jax.Array, jit: bool, is_sparse: bool):

def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput:
solver = sinkhorn.Sinkhorn(lse_mode=False)
Expand All @@ -208,6 +213,8 @@ def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput:
x = jax.random.normal(rng, (n,))

gt_geom = gt_geometry(G, epsilon=eps)
if is_sparse:
G = jesp.BCOO.fromdense(G)
graph_geom = geodesic.Geodesic.from_graph(G, t=eps / 4.0)

fn = jax.jit(callback) if jit else callback
Expand Down Expand Up @@ -257,11 +264,29 @@ def callback(geom: geodesic.Geodesic) -> float:
@pytest.mark.parametrize("normalize", [False, True])
@pytest.mark.parametrize("t", [5, 10, 50])
@pytest.mark.parametrize("order", [20, 30, 40])
def test_heat_approx(self, normalize: bool, t: float, order: int):
@pytest.mark.parametrize("is_sparse", [True, False])
def test_heat_approx(
self, normalize: bool, t: float, order: int, is_sparse: bool
):
G = random_graph(20, p=0.5)
exact = exact_heat_kernel(G, normalize=normalize, t=t)
if is_sparse:
G = jesp.BCOO.fromdense(G)
geom = geodesic.Geodesic.from_graph(
G, t=t, order=order, normalize=normalize
)
approx = geom.apply_kernel(jnp.eye(G.shape[0]))
guillaumehu marked this conversation as resolved.
Show resolved Hide resolved

np.testing.assert_allclose(exact, approx, rtol=1e-1, atol=1e-1)

@pytest.mark.limit_memory("150 MB")
def test_sparse_geo_memory(self, rng: jax.Array):
n = 10_000
G = random_graph(n, p=0.001, is_sparse=True)
x = jax.random.normal(rng, (n,))

graph_geom = geodesic.Geodesic.from_graph(G, t=1.0, order=10)

out = jax.jit(graph_geom.apply_kernel)(x)

np.testing.assert_array_equal(jnp.isfinite(out), True)
Loading