From 09810be0cd6037c2c24b2297159bed187da49ff3 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Tue, 7 May 2024 13:40:25 -0700
Subject: [PATCH] Implement jnp.linalg.multi_dot using opt_einsum

---
 docs/aot.md                          |   2 +
 jax/_src/numpy/linalg.py             |  94 +++++++++++++++++++++++
 jax/_src/third_party/numpy/linalg.py | 110 ---------------------------
 jax/numpy/linalg.py                  |   2 +-
 4 files changed, 97 insertions(+), 111 deletions(-)

diff --git a/docs/aot.md b/docs/aot.md
index 1a7ec0080e61..3304f4081b6a 100644
--- a/docs/aot.md
+++ b/docs/aot.md
@@ -1,3 +1,5 @@
+(ahead-of-time-lowering)=
+
 # Ahead-of-time lowering and compilation
 
 JAX offers several transformations, such as `jax.jit` and `jax.pmap`, returning
diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py
index e8ff77242d9f..e60b283012e8 100644
--- a/jax/_src/numpy/linalg.py
+++ b/jax/_src/numpy/linalg.py
@@ -16,6 +16,7 @@
 
 from collections.abc import Sequence
 from functools import partial
+import itertools
 import math
 import warnings
 
@@ -28,6 +29,7 @@
 
 from jax import lax
 from jax._src.lax import lax as lax_internal
+from jax._src.lax.lax import PrecisionLike
 from jax._src.lax import linalg as lax_linalg
 from jax._src.numpy import lax_numpy as jnp
 from jax._src.numpy import reductions, ufuncs
@@ -1924,3 +1926,95 @@ def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None)
                      f" got a.shape={a_arr.shape}, b.ndim={b_arr.ndim}.")
   a_arr = a_arr.reshape(b_arr.size, math.prod(out_shape))
   return solve(a_arr, b_arr.ravel()).reshape(out_shape)
+
+
+def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -> Array:
+  """Efficiently compute matrix products between a sequence of arrays.
+
+  JAX implementation of :func:`numpy.linalg.multi_dot`.
+
+  JAX internally uses the opt_einsum library to compute the most efficient
+  operation order.
+
+  Args:
+    arrays: sequence of arrays. All must be two-dimensional, except the first
+      and last which may be one-dimensional.
+    precision: either ``None`` (default), which means the default precision for
+      the backend, or a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
+      ``Precision.HIGH`` or ``Precision.HIGHEST``).
+
+  Returns:
+    an array representing the equivalent of ``reduce(jnp.matmul, arrays)``, but
+    evaluated in the optimal order.
+
+  This function exists because the cost of computing sequences of matmul operations
+  can differ vastly depending on the order in which the operations are evaluated.
+  For a single matmul, the number of floating point operations (flops) required to
+  compute a matrix product can be approximated this way:
+
+  >>> def approx_flops(x, y):
+  ...   # for 2D x and y, with x.shape[1] == y.shape[0]
+  ...   return 2 * x.shape[0] * x.shape[1] * y.shape[1]
+
+  Suppose we have three matrices that we'd like to multiply in sequence:
+
+  >>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
+  >>> x = jax.random.normal(key1, shape=(200, 5))
+  >>> y = jax.random.normal(key2, shape=(5, 100))
+  >>> z = jax.random.normal(key3, shape=(100, 10))
+
+  Because of associativity of matrix products, there are two orders in which we might
+  evaluate the product ``x @ y @ z``, and both produce equivalent outputs up to floating
+  point precision:
+
+  >>> result1 = (x @ y) @ z
+  >>> result2 = x @ (y @ z)
+  >>> jnp.allclose(result1, result2, atol=1E-4)
+  Array(True, dtype=bool)
+
+  But the computational cost of these differs greatly:
+
+  >>> print("(x @ y) @ z flops:", approx_flops(x, y) + approx_flops(x @ y, z))
+  (x @ y) @ z flops: 600000
+  >>> print("x @ (y @ z) flops:", approx_flops(y, z) + approx_flops(x, y @ z))
+  x @ (y @ z) flops: 30000
+
+  The second approach is about 20x more efficient in terms of estimated flops!
+
+  ``multi_dot`` is a function that will automatically choose the fastest
+  computational path for such problems:
+
+  >>> result3 = jnp.linalg.multi_dot([x, y, z])
+  >>> jnp.allclose(result1, result3, atol=1E-4)
+  Array(True, dtype=bool)
+
+  We can use JAX's :ref:`ahead-of-time-lowering` tools to estimate the total flops
+  of each approach, and confirm that ``multi_dot`` is choosing the more efficient
+  option:
+
+  >>> jax.jit(lambda x, y, z: (x @ y) @ z).lower(x, y, z).cost_analysis()['flops']
+  600000.0
+  >>> jax.jit(lambda x, y, z: x @ (y @ z)).lower(x, y, z).cost_analysis()['flops']
+  30000.0
+  >>> jax.jit(jnp.linalg.multi_dot).lower([x, y, z]).cost_analysis()['flops']
+  30000.0
+  """
+  check_arraylike('jnp.linalg.multi_dot', *arrays)
+  arrs: list[Array] = list(map(jnp.asarray, arrays))
+  if len(arrs) < 2:
+    raise ValueError(f"multi_dot requires at least two arrays; got len(arrays)={len(arrs)}")
+  if not (arrs[0].ndim in (1, 2) and arrs[-1].ndim in (1, 2) and
+          all(a.ndim == 2 for a in arrs[1:-1])):
+    raise ValueError("multi_dot: input arrays must all be two-dimensional, except for"
+                     " the first and last array which may be 1 or 2 dimensional."
+                     f" Got array shapes {[a.shape for a in arrs]}")
+  if any(a.shape[-1] != b.shape[0] for a, b in zip(arrs[:-1], arrs[1:])):
+    raise ValueError("multi_dot: last dimension of each array must match first dimension"
+                     f" of following array. Got array shapes {[a.shape for a in arrs]}")
+  einsum_axes: list[tuple[int, ...]] = [(i, i+1) for i in range(len(arrs))]
+  if arrs[0].ndim == 1:
+    einsum_axes[0] = einsum_axes[0][1:]
+  if arrs[-1].ndim == 1:
+    einsum_axes[-1] = einsum_axes[-1][:1]
+  return jnp.einsum(*itertools.chain(*zip(arrs, einsum_axes)),  # type: ignore[arg-type, call-overload]
+                    optimize='optimal', precision=precision)
diff --git a/jax/_src/third_party/numpy/linalg.py b/jax/_src/third_party/numpy/linalg.py
index db1dd51bf742..db0bd99c7360 100644
--- a/jax/_src/third_party/numpy/linalg.py
+++ b/jax/_src/third_party/numpy/linalg.py
@@ -32,13 +32,6 @@ def _assertNdSquareness(*arrays):
         'Last 2 dimensions of the array must be square')
 
 
-def _assert2d(*arrays):
-  for a in arrays:
-    if a.ndim != 2:
-      raise ValueError(f'{a.ndim}-dimensional array given. '
-                       'Array must be two-dimensional')
-
-
 @implements(np.linalg.cond)
 def cond(x, p=None):
   check_arraylike('jnp.linalg.cond', x)
@@ -60,106 +53,3 @@ def cond(x, p=None):
   nan_mask = jnp.logical_and(jnp.isnan(r), ~jnp.isnan(x).any(axis=(-2, -1)))
   r = jnp.where(orig_nan_check, jnp.where(nan_mask, jnp.inf, r), r)
   return r
-
-
-@implements(np.linalg.multi_dot)
-def multi_dot(arrays, *, precision=None):
-  check_arraylike('jnp.linalg.multi_dot', *arrays)
-  n = len(arrays)
-  # optimization only makes sense for len(arrays) > 2
-  if n < 2:
-    raise ValueError("Expecting at least two arrays.")
-  elif n == 2:
-    return jnp.dot(arrays[0], arrays[1], precision=precision)
-
-  arrays = [jnp.asarray(a) for a in arrays]
-
-  # save original ndim to reshape the result array into the proper form later
-  ndim_first, ndim_last = arrays[0].ndim, arrays[-1].ndim
-  # Explicitly convert vectors to 2D arrays to keep the logic of the internal
-  # _multi_dot_* functions as simple as possible.
-  if arrays[0].ndim == 1:
-    arrays[0] = jnp.atleast_2d(arrays[0])
-  if arrays[-1].ndim == 1:
-    arrays[-1] = jnp.atleast_2d(arrays[-1]).T
-  _assert2d(*arrays)
-
-  # _multi_dot_three is much faster than _multi_dot_matrix_chain_order
-  if n == 3:
-    result = _multi_dot_three(*arrays, precision)
-  else:
-    order = _multi_dot_matrix_chain_order(arrays)
-    result = _multi_dot(arrays, order, 0, n - 1, precision)
-
-  # return proper shape
-  if ndim_first == 1 and ndim_last == 1:
-    return result[0, 0]  # scalar
-  elif ndim_first == 1 or ndim_last == 1:
-    return result.ravel()  # 1-D
-  else:
-    return result
-
-
-def _multi_dot_three(A, B, C, precision):
-  """
-  Find the best order for three arrays and do the multiplication.
-  For three arguments `_multi_dot_three` is approximately 15 times faster
-  than `_multi_dot_matrix_chain_order`
-  """
-  a0, a1b0 = A.shape
-  b1c0, c1 = C.shape
-  # cost1 = cost((AB)C) = a0*a1b0*b1c0 + a0*b1c0*c1
-  cost1 = a0 * b1c0 * (a1b0 + c1)
-  # cost2 = cost(A(BC)) = a1b0*b1c0*c1 + a0*a1b0*c1
-  cost2 = a1b0 * c1 * (a0 + b1c0)
-
-  if cost1 < cost2:
-    return jnp.dot(jnp.dot(A, B, precision=precision), C, precision=precision)
-  else:
-    return jnp.dot(A, jnp.dot(B, C, precision=precision), precision=precision)
-
-
-def _multi_dot_matrix_chain_order(arrays, return_costs=False):
-  """
-  Return a jnp.array that encodes the optimal order of mutiplications.
-  The optimal order array is then used by `_multi_dot()` to do the
-  multiplication.
-  Also return the cost matrix if `return_costs` is `True`
-  The implementation CLOSELY follows Cormen, "Introduction to Algorithms",
-  Chapter 15.2, p. 370-378. Note that Cormen uses 1-based indices.
-      cost[i, j] = min([
-          cost[prefix] + cost[suffix] + cost_mult(prefix, suffix)
-          for k in range(i, j)])
-  """
-  n = len(arrays)
-  # p stores the dimensions of the matrices
-  # Example for p: A_{10x100}, B_{100x5}, C_{5x50} --> p = [10, 100, 5, 50]
-  p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]]
-  # m is a matrix of costs of the subproblems
-  # m[i,j]: min number of scalar multiplications needed to compute A_{i..j}
-  m = np.zeros((n, n), dtype=np.double)
-  # s is the actual ordering
-  # s[i, j] is the value of k at which we split the product A_i..A_j
-  s = np.empty((n, n), dtype=np.intp)
-
-  for l in range(1, n):
-    for i in range(n - l):
-      j = i + l
-      m[i, j] = jnp.inf
-      for k in range(i, j):
-        q = m[i, k] + m[k+1, j] + p[i]*p[k+1]*p[j+1]
-        if q < m[i, j]:
-          m[i, j] = q
-          s[i, j] = k  # Note that Cormen uses 1-based index
-
-  return (s, m) if return_costs else s
-
-
-def _multi_dot(arrays, order, i, j, precision):
-  """Actually do the multiplication with the given order."""
-  if i == j:
-    return arrays[i]
-  else:
-    return jnp.dot(_multi_dot(arrays, order, i, order[i, j], precision),
-                   _multi_dot(arrays, order, order[i, j] + 1, j, precision),
-                   precision=precision)
diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py
index aff1f5b818b9..8a6db2bc476a 100644
--- a/jax/numpy/linalg.py
+++ b/jax/numpy/linalg.py
@@ -31,6 +31,7 @@
     matrix_power as matrix_power,
     matrix_rank as matrix_rank,
     matrix_transpose as matrix_transpose,
+    multi_dot as multi_dot,
     norm as norm,
     outer as outer,
     pinv as pinv,
@@ -47,5 +48,4 @@
 )
 from jax._src.third_party.numpy.linalg import (
     cond as cond,
-    multi_dot as multi_dot,
 )
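
Usage note: below is a minimal sketch (not part of the patch above) exercising the new
jnp.linalg.multi_dot added in jax/_src/numpy/linalg.py, in particular the 1-D first/last
argument handling implemented via the einsum_axes trimming. The shapes and random keys are
illustrative only.

    import jax
    import jax.numpy as jnp

    key1, key2, key3, key4 = jax.random.split(jax.random.key(0), 4)
    v = jax.random.normal(key1, shape=(200,))    # 1-D first argument
    x = jax.random.normal(key2, shape=(200, 5))
    y = jax.random.normal(key3, shape=(5, 100))
    w = jax.random.normal(key4, shape=(100,))    # 1-D last argument

    # Equivalent to reduce(jnp.matmul, [v, x, y, w]), but contracted in the
    # cheapest order chosen by opt_einsum's 'optimal' path.
    out = jnp.linalg.multi_dot([v, x, y, w])
    print(out.shape)                                    # (): both endpoints are 1-D
    print(jnp.allclose(out, v @ x @ y @ w, atol=1e-4))  # True

Delegating the contraction ordering to jnp.einsum(..., optimize='optimal') is what allows the
hand-rolled matrix-chain-order routines in jax/_src/third_party/numpy/linalg.py to be deleted.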