Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sparse Chebyshev approximation #502

Merged
merged 15 commits into from
Mar 18, 2024
47 changes: 37 additions & 10 deletions src/ott/geometry/geodesic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import jax
import jax.experimental.sparse as jesp
Expand All @@ -27,6 +27,38 @@
__all__ = ["Geodesic"]


guillaumehu marked this conversation as resolved.
Show resolved Hide resolved
def compute_dense_laplacian(
guillaumehu marked this conversation as resolved.
Show resolved Hide resolved
G: jnp.ndarray, normalize: bool = False
) -> jnp.ndarray:
degree = jnp.sum(G, axis=1)
laplacian = jnp.diag(degree) - G
if normalize:
inv_sqrt_deg = jnp.diag(
jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0)
)
laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg
guillaumehu marked this conversation as resolved.
Show resolved Hide resolved
return laplacian


def compute_sparse_laplacian(
G: jesp.BCOO, normalize: bool = False
) -> jesp.BCOO:
n, _ = G.shape
data_degree, ixs = G.sum(1).todense(), jnp.arange(n)
degree = jesp.BCOO((data_degree, jnp.c_[ixs, ixs]), shape=(n, n))
laplacian = degree - G
if normalize:
data_inv = jnp.where(data_degree > 0., 1. / jnp.sqrt(data_degree), 0.)
laplacian = data_inv[:, None] * laplacian * data_inv[None, :]
return laplacian


def compute_laplacian(G: Array_g, normalize: bool = False) -> Array_g:
guillaumehu marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(G, jesp.BCOO):
return compute_sparse_laplacian(G, normalize)
return compute_dense_laplacian(G, normalize)


@jax.tree_util.register_pytree_node_class
class Geodesic(geometry.Geometry):
r"""Graph distance approximation using heat kernel :cite:`huguet:2023`.
Expand Down Expand Up @@ -106,13 +138,7 @@ def from_graph(
if t is None:
t = (jnp.sum(G) / jnp.sum(G > 0.0)) ** 2

degree = jnp.sum(G, axis=1)
laplacian = jnp.diag(degree) - G
if normalize:
inv_sqrt_deg = jnp.diag(
jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0)
)
laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg
laplacian = compute_laplacian(G, normalize)

if eigval is None:
eigval = compute_largest_eigenvalue(laplacian, rng)
Expand Down Expand Up @@ -242,8 +268,9 @@ def compute_largest_eigenvalue(


def expm_multiply(
L: jnp.ndarray, X: jnp.ndarray, coeff: jnp.ndarray, eigval: float
) -> jnp.ndarray:
L: Union[jnp.ndarray, jesp.BCOO], X: jnp.ndarray, coeff: jnp.ndarray,
eigval: float
) -> Union[jnp.ndarray, jesp.BCOO]:
guillaumehu marked this conversation as resolved.
Show resolved Hide resolved

def body(carry, c):
T0, T1, Y = carry
Expand Down
14 changes: 12 additions & 2 deletions tests/geometry/geodesic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import networkx as nx
import numpy as np
import pytest
from jax.experimental import sparse
guillaumehu marked this conversation as resolved.
Show resolved Hide resolved
from networkx.algorithms import shortest_paths
from networkx.generators import balanced_tree, random_graphs
from ott.geometry import geodesic, geometry, graph
Expand Down Expand Up @@ -196,7 +197,8 @@ def laplacian(G: jnp.ndarray) -> jnp.ndarray:
np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)

@pytest.mark.fast.with_args(jit=[False, True], only_fast=0)
def test_geo_sinkhorn(self, rng: jax.Array, jit: bool):
@pytest.mark.parametrize("is_sparse", [True, False])
def test_geo_sinkhorn(self, rng: jax.Array, jit: bool, is_sparse: bool):

def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput:
solver = sinkhorn.Sinkhorn(lse_mode=False)
Expand All @@ -208,6 +210,8 @@ def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput:
x = jax.random.normal(rng, (n,))

gt_geom = gt_geometry(G, epsilon=eps)
if is_sparse:
G = sparse.BCOO.fromdense(G)
graph_geom = geodesic.Geodesic.from_graph(G, t=eps / 4.0)

fn = jax.jit(callback) if jit else callback
Expand Down Expand Up @@ -257,11 +261,17 @@ def callback(geom: geodesic.Geodesic) -> float:
@pytest.mark.parametrize("normalize", [False, True])
@pytest.mark.parametrize("t", [5, 10, 50])
@pytest.mark.parametrize("order", [20, 30, 40])
def test_heat_approx(self, normalize: bool, t: float, order: int):
@pytest.mark.parametrize("is_sparse", [True, False])
def test_heat_approx(
self, normalize: bool, t: float, order: int, is_sparse: bool
):
G = random_graph(20, p=0.5)
exact = exact_heat_kernel(G, normalize=normalize, t=t)
if is_sparse:
G = sparse.BCOO.fromdense(G)
geom = geodesic.Geodesic.from_graph(
G, t=t, order=order, normalize=normalize
)
approx = geom.apply_kernel(jnp.eye(G.shape[0]))
guillaumehu marked this conversation as resolved.
Show resolved Hide resolved

np.testing.assert_allclose(exact, approx, rtol=1e-1, atol=1e-1)
Loading