diff --git a/src/ott/geometry/geodesic.py b/src/ott/geometry/geodesic.py index 9375b2364..563352920 100644 --- a/src/ott/geometry/geodesic.py +++ b/src/ott/geometry/geodesic.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple, Union import jax import jax.experimental.sparse as jesp @@ -22,10 +22,11 @@ from ott import utils from ott.geometry import geometry from ott.math import utils as mu -from ott.types import Array_g __all__ = ["Geodesic"] +Array_g = Union[jnp.ndarray, jesp.BCOO] + @jax.tree_util.register_pytree_node_class class Geodesic(geometry.Geometry): @@ -106,13 +107,10 @@ def from_graph( if t is None: t = (jnp.sum(G) / jnp.sum(G > 0.0)) ** 2 - degree = jnp.sum(G, axis=1) - laplacian = jnp.diag(degree) - G - if normalize: - inv_sqrt_deg = jnp.diag( - jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0) - ) - laplacian = inv_sqrt_deg @ laplacian @ inv_sqrt_deg + if isinstance(G, jesp.BCOO): + laplacian = compute_sparse_laplacian(G, normalize) + else: + laplacian = compute_dense_laplacian(G, normalize) if eigval is None: eigval = compute_largest_eigenvalue(laplacian, rng) @@ -220,6 +218,33 @@ def tree_unflatten( # noqa: D102 return cls(*children, **aux_data) +def normalize_laplacian(laplacian: Array_g, degree: jnp.ndarray) -> Array_g: + inv_sqrt_deg = jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0) + return inv_sqrt_deg[:, None] * laplacian * inv_sqrt_deg[None, :] + + +def compute_dense_laplacian( + G: jnp.ndarray, normalize: bool = False +) -> jnp.ndarray: + degree = jnp.sum(G, axis=1) + laplacian = jnp.diag(degree) - G + if normalize: + laplacian = normalize_laplacian(laplacian, degree) + return laplacian + + +def compute_sparse_laplacian( + G: jesp.BCOO, normalize: bool = False +) -> jesp.BCOO: + n, _ = G.shape + data_degree, ixs = G.sum(1).todense(), jnp.arange(n) + degree = jesp.BCOO((data_degree, jnp.c_[ixs, ixs]), shape=(n, n)) + laplacian = degree - G + if normalize: + laplacian = normalize_laplacian(laplacian, data_degree) + return laplacian + + def compute_largest_eigenvalue( laplacian_matrix: jnp.ndarray, rng: jax.Array, @@ -242,7 +267,7 @@ def compute_largest_eigenvalue( def expm_multiply( - L: jnp.ndarray, X: jnp.ndarray, coeff: jnp.ndarray, eigval: float + L: Array_g, X: jnp.ndarray, coeff: jnp.ndarray, eigval: float ) -> jnp.ndarray: def body(carry, c): diff --git a/src/ott/types.py b/src/ott/types.py index f6eefb797..7a4c88716 100644 --- a/src/ott/types.py +++ b/src/ott/types.py @@ -11,18 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Protocol, Union +from typing import Protocol -import jax.experimental.sparse as jesp import jax.numpy as jnp -__all__ = ["Transport", "Array_g"] +__all__ = ["Transport"] # TODO(michalk8): introduce additional types here -# Either a dense or sparse array. -Array_g = Union[jnp.ndarray, jesp.BCOO] - class Transport(Protocol): """Interface for the solution of a transport problem. diff --git a/tests/geometry/geodesic_test.py b/tests/geometry/geodesic_test.py index 45877cc34..97867a78d 100644 --- a/tests/geometry/geodesic_test.py +++ b/tests/geometry/geodesic_test.py @@ -14,6 +14,7 @@ from typing import Optional, Union import jax +import jax.experimental.sparse as jesp import jax.numpy as jnp import networkx as nx import numpy as np @@ -29,6 +30,7 @@ def random_graph( n: int, p: float = 0.3, seed: Optional[int] = 0, + is_sparse: bool = False, *, return_laplacian: bool = False, directed: bool = False, @@ -45,6 +47,8 @@ def random_graph( G ) if return_laplacian else nx.linalg.adjacency_matrix(G) + if is_sparse: + return jesp.BCOO.from_scipy_sparse(G) return jnp.asarray(G.toarray()) @@ -196,7 +200,8 @@ def laplacian(G: jnp.ndarray) -> jnp.ndarray: np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5) @pytest.mark.fast.with_args(jit=[False, True], only_fast=0) - def test_geo_sinkhorn(self, rng: jax.Array, jit: bool): + @pytest.mark.parametrize("is_sparse", [True, False]) + def test_geo_sinkhorn(self, rng: jax.Array, jit: bool, is_sparse: bool): def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput: solver = sinkhorn.Sinkhorn(lse_mode=False) @@ -208,6 +213,8 @@ def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput: x = jax.random.normal(rng, (n,)) gt_geom = gt_geometry(G, epsilon=eps) + if is_sparse: + G = jesp.BCOO.fromdense(G) graph_geom = geodesic.Geodesic.from_graph(G, t=eps / 4.0) fn = jax.jit(callback) if jit else callback @@ -257,11 +264,29 @@ def callback(geom: geodesic.Geodesic) -> float: @pytest.mark.parametrize("normalize", [False, True]) @pytest.mark.parametrize("t", [5, 10, 50]) @pytest.mark.parametrize("order", [20, 30, 40]) - def test_heat_approx(self, normalize: bool, t: float, order: int): + @pytest.mark.parametrize("is_sparse", [True, False]) + def test_heat_approx( + self, normalize: bool, t: float, order: int, is_sparse: bool + ): G = random_graph(20, p=0.5) exact = exact_heat_kernel(G, normalize=normalize, t=t) + if is_sparse: + G = jesp.BCOO.fromdense(G) geom = geodesic.Geodesic.from_graph( G, t=t, order=order, normalize=normalize ) approx = geom.apply_kernel(jnp.eye(G.shape[0])) + np.testing.assert_allclose(exact, approx, rtol=1e-1, atol=1e-1) + + @pytest.mark.limit_memory("150 MB") + def test_sparse_geo_memory(self, rng: jax.Array): + n = 10_000 + G = random_graph(n, p=0.001, is_sparse=True) + x = jax.random.normal(rng, (n,)) + + graph_geom = geodesic.Geodesic.from_graph(G, t=1.0, order=10) + + out = jax.jit(graph_geom.apply_kernel)(x) + + np.testing.assert_array_equal(jnp.isfinite(out), True)