Skip to content

Commit

Permalink
Add sparse Chebyshev approximation (#502)
Browse files Browse the repository at this point in the history
* wip bcoo geodesic

* add sparse laplacian

* add test & pass test

* typo & fix type

* appease code formatter

* fmt

* fix laplacian type

* norm lap with elementwise multiplication

* unify tests

* fix sparse scan & sinkhorn test

* fmt

* rm mv to sparse since `@jesp.sparsify`

* rm sparsify wrapper and fix type

* fix type & mv fn & test memory
  • Loading branch information
guillaumehu authored Mar 18, 2024
1 parent da704fc commit 51a658a
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 18 deletions.
45 changes: 35 additions & 10 deletions src/ott/geometry/geodesic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import jax
import jax.experimental.sparse as jesp
Expand All @@ -22,10 +22,11 @@
from ott import utils
from ott.geometry import geometry
from ott.math import utils as mu
from ott.types import Array_g

__all__ = ["Geodesic"]

Array_g = Union[jnp.ndarray, jesp.BCOO]


@jax.tree_util.register_pytree_node_class
class Geodesic(geometry.Geometry):
Expand Down Expand Up @@ -106,13 +107,10 @@ def from_graph(
if t is None:
t = (jnp.sum(G) / jnp.sum(G > 0.0)) ** 2

degree = jnp.sum(G, axis=1)
laplacian = jnp.diag(degree) - G
if normalize:
inv_sqrt_deg = jnp.diag(
jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0)
)
laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg
if isinstance(G, jesp.BCOO):
laplacian = compute_sparse_laplacian(G, normalize)
else:
laplacian = compute_dense_laplacian(G, normalize)

if eigval is None:
eigval = compute_largest_eigenvalue(laplacian, rng)
Expand Down Expand Up @@ -220,6 +218,33 @@ def tree_unflatten( # noqa: D102
return cls(*children, **aux_data)


def normalize_laplacian(laplacian: Array_g, degree: jnp.ndarray) -> Array_g:
inv_sqrt_deg = jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0)
return inv_sqrt_deg[:, None] * laplacian * inv_sqrt_deg[None, :]


def compute_dense_laplacian(
G: jnp.ndarray, normalize: bool = False
) -> jnp.ndarray:
degree = jnp.sum(G, axis=1)
laplacian = jnp.diag(degree) - G
if normalize:
laplacian = normalize_laplacian(laplacian, degree)
return laplacian


def compute_sparse_laplacian(
G: jesp.BCOO, normalize: bool = False
) -> jesp.BCOO:
n, _ = G.shape
data_degree, ixs = G.sum(1).todense(), jnp.arange(n)
degree = jesp.BCOO((data_degree, jnp.c_[ixs, ixs]), shape=(n, n))
laplacian = degree - G
if normalize:
laplacian = normalize_laplacian(laplacian, data_degree)
return laplacian


def compute_largest_eigenvalue(
laplacian_matrix: jnp.ndarray,
rng: jax.Array,
Expand All @@ -242,7 +267,7 @@ def compute_largest_eigenvalue(


def expm_multiply(
L: jnp.ndarray, X: jnp.ndarray, coeff: jnp.ndarray, eigval: float
L: Array_g, X: jnp.ndarray, coeff: jnp.ndarray, eigval: float
) -> jnp.ndarray:

def body(carry, c):
Expand Down
8 changes: 2 additions & 6 deletions src/ott/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Protocol, Union
from typing import Protocol

import jax.experimental.sparse as jesp
import jax.numpy as jnp

__all__ = ["Transport", "Array_g"]
__all__ = ["Transport"]

# TODO(michalk8): introduce additional types here

# Either a dense or sparse array.
Array_g = Union[jnp.ndarray, jesp.BCOO]


class Transport(Protocol):
"""Interface for the solution of a transport problem.
Expand Down
29 changes: 27 additions & 2 deletions tests/geometry/geodesic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Optional, Union

import jax
import jax.experimental.sparse as jesp
import jax.numpy as jnp
import networkx as nx
import numpy as np
Expand All @@ -29,6 +30,7 @@ def random_graph(
n: int,
p: float = 0.3,
seed: Optional[int] = 0,
is_sparse: bool = False,
*,
return_laplacian: bool = False,
directed: bool = False,
Expand All @@ -45,6 +47,8 @@ def random_graph(
G
) if return_laplacian else nx.linalg.adjacency_matrix(G)

if is_sparse:
return jesp.BCOO.from_scipy_sparse(G)
return jnp.asarray(G.toarray())


Expand Down Expand Up @@ -196,7 +200,8 @@ def laplacian(G: jnp.ndarray) -> jnp.ndarray:
np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)

@pytest.mark.fast.with_args(jit=[False, True], only_fast=0)
def test_geo_sinkhorn(self, rng: jax.Array, jit: bool):
@pytest.mark.parametrize("is_sparse", [True, False])
def test_geo_sinkhorn(self, rng: jax.Array, jit: bool, is_sparse: bool):

def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput:
solver = sinkhorn.Sinkhorn(lse_mode=False)
Expand All @@ -208,6 +213,8 @@ def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput:
x = jax.random.normal(rng, (n,))

gt_geom = gt_geometry(G, epsilon=eps)
if is_sparse:
G = jesp.BCOO.fromdense(G)
graph_geom = geodesic.Geodesic.from_graph(G, t=eps / 4.0)

fn = jax.jit(callback) if jit else callback
Expand Down Expand Up @@ -257,11 +264,29 @@ def callback(geom: geodesic.Geodesic) -> float:
@pytest.mark.parametrize("normalize", [False, True])
@pytest.mark.parametrize("t", [5, 10, 50])
@pytest.mark.parametrize("order", [20, 30, 40])
def test_heat_approx(self, normalize: bool, t: float, order: int):
@pytest.mark.parametrize("is_sparse", [True, False])
def test_heat_approx(
self, normalize: bool, t: float, order: int, is_sparse: bool
):
G = random_graph(20, p=0.5)
exact = exact_heat_kernel(G, normalize=normalize, t=t)
if is_sparse:
G = jesp.BCOO.fromdense(G)
geom = geodesic.Geodesic.from_graph(
G, t=t, order=order, normalize=normalize
)
approx = geom.apply_kernel(jnp.eye(G.shape[0]))

np.testing.assert_allclose(exact, approx, rtol=1e-1, atol=1e-1)

@pytest.mark.limit_memory("150 MB")
def test_sparse_geo_memory(self, rng: jax.Array):
n = 10_000
G = random_graph(n, p=0.001, is_sparse=True)
x = jax.random.normal(rng, (n,))

graph_geom = geodesic.Geodesic.from_graph(G, t=1.0, order=10)

out = jax.jit(graph_geom.apply_kernel)(x)

np.testing.assert_array_equal(jnp.isfinite(out), True)

0 comments on commit 51a658a

Please sign in to comment.