Merge pull request #21115 from jakevdp:multi-dot
PiperOrigin-RevId: 631545161
jax authors committed May 7, 2024
2 parents 2b3251e + 09810be commit 78e10ee
Showing 4 changed files with 97 additions and 111 deletions.
2 changes: 2 additions & 0 deletions docs/aot.md
@@ -1,3 +1,5 @@
(ahead-of-time-lowering)=

# Ahead-of-time lowering and compilation

JAX offers several transformations, such as `jax.jit` and `jax.pmap`, returning
94 changes: 94 additions & 0 deletions jax/_src/numpy/linalg.py
@@ -16,6 +16,7 @@

from collections.abc import Sequence
from functools import partial
import itertools
import math
import warnings

@@ -28,6 +29,7 @@
from jax import lax

from jax._src.lax import lax as lax_internal
from jax._src.lax.lax import PrecisionLike
from jax._src.lax import linalg as lax_linalg
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions, ufuncs
@@ -1924,3 +1926,95 @@ def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None)
f" got a.shape={a_arr.shape}, b.ndim={b_arr.ndim}.")
a_arr = a_arr.reshape(b_arr.size, math.prod(out_shape))
return solve(a_arr, b_arr.ravel()).reshape(out_shape)


def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -> Array:
"""Efficiently compute matrix products between a sequence of arrays.
JAX implementation of :func:`numpy.linalg.multi_dot`.
JAX internally uses the opt_einsum library to compute the most efficient
operation order.
Args:
arrays: sequence of arrays. All must be two-dimensional, except the first
and last which may be one-dimensional.
precision: either ``None`` (default), which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``).
Returns:
an array representing the equivalent of ``reduce(jnp.matmul, arrays)``, but
evaluated in the optimal order.
This function exists because the cost of computing sequences of matmul operations
can differ vastly depending on the order in which the operations are evaluated.
For a single matmul, the number of floating point operations (flops) required to
compute a matrix product can be approximated this way:
>>> def approx_flops(x, y):
... # for 2D x and y, with x.shape[1] == y.shape[0]
... return 2 * x.shape[0] * x.shape[1] * y.shape[1]
Suppose we have three matrices that we'd like to multiply in sequence:
>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> x = jax.random.normal(key1, shape=(200, 5))
>>> y = jax.random.normal(key2, shape=(5, 100))
>>> z = jax.random.normal(key3, shape=(100, 10))
Because of associativity of matrix products, there are two orders in which we might
evaluate the product ``x @ y @ z``, and both produce equivalent outputs up to floating
point precision:
>>> result1 = (x @ y) @ z
>>> result2 = x @ (y @ z)
>>> jnp.allclose(result1, result2, atol=1E-4)
Array(True, dtype=bool)
But the computational cost of these differ greatly:
>>> print("(x @ y) @ z flops:", approx_flops(x, y) + approx_flops(x @ y, z))
(x @ y) @ z flops: 600000
>>> print("x @ (y @ z) flops:", approx_flops(y, z) + approx_flops(x, y @ z))
x @ (y @ z) flops: 30000
The second approach is about 20x more efficient in terms of estimated flops!
``multi_dot`` is a function that will automatically choose the fastest
computational path for such problems:
>>> result3 = jnp.linalg.multi_dot([x, y, z])
>>> jnp.allclose(result1, result3, atol=1E-4)
Array(True, dtype=bool)
We can use JAX's :ref:`ahead-of-time-lowering` tools to estimate the total flops
of each approach, and confirm that ``multi_dot`` is choosing the more efficient
option:
>>> jax.jit(lambda x, y, z: (x @ y) @ z).lower(x, y, z).cost_analysis()['flops']
600000.0
>>> jax.jit(lambda x, y, z: x @ (y @ z)).lower(x, y, z).cost_analysis()['flops']
30000.0
>>> jax.jit(jnp.linalg.multi_dot).lower([x, y, z]).cost_analysis()['flops']
30000.0
"""
  check_arraylike('jnp.linalg.multi_dot', *arrays)
  arrs: list[Array] = list(map(jnp.asarray, arrays))
  if len(arrs) < 2:
    raise ValueError(f"multi_dot requires at least two arrays; got len(arrays)={len(arrs)}")
  if not (arrs[0].ndim in (1, 2) and arrs[-1].ndim in (1, 2) and
          all(a.ndim == 2 for a in arrs[1:-1])):
    raise ValueError("multi_dot: input arrays must all be two-dimensional, except for"
                     " the first and last array which may be 1 or 2 dimensional."
                     f" Got array shapes {[a.shape for a in arrs]}")
  if any(a.shape[-1] != b.shape[0] for a, b in zip(arrs[:-1], arrs[1:])):
    raise ValueError("multi_dot: last dimension of each array must match first dimension"
                     f" of following array. Got array shapes {[a.shape for a in arrs]}")
  einsum_axes: list[tuple[int, ...]] = [(i, i+1) for i in range(len(arrs))]
  if arrs[0].ndim == 1:
    einsum_axes[0] = einsum_axes[0][1:]
  if arrs[-1].ndim == 1:
    einsum_axes[-1] = einsum_axes[-1][:1]
  return jnp.einsum(*itertools.chain(*zip(arrs, einsum_axes)),  # type: ignore[arg-type, call-overload]
                    optimize='optimal', precision=precision)
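
Editor's note: the following usage sketch is not part of the commit; the array names and shapes are illustrative, mirroring the docstring example, and show the 1-D first/last operand handling of the new function.

# Illustrative sketch only (not part of this commit): exercising the new
# jnp.linalg.multi_dot with 1-D first and last operands.
import jax
import jax.numpy as jnp

key1, key2, key3, key4 = jax.random.split(jax.random.key(0), 4)
v = jax.random.normal(key1, shape=(200,))    # 1-D first operand
x = jax.random.normal(key2, shape=(200, 5))
y = jax.random.normal(key3, shape=(5, 100))
w = jax.random.normal(key4, shape=(100,))    # 1-D last operand

# Equivalent to v @ x @ y @ w, but opt_einsum picks the cheapest association.
out = jnp.linalg.multi_dot([v, x, y, w])
print(out.shape)  # () -- the 1-D endpoints contract away, leaving a scalar

# Per the implementation above, this lowers to a single interleaved einsum call,
# roughly: jnp.einsum(v, (1,), x, (1, 2), y, (2, 3), w, (3,), optimize='optimal')
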
110 changes: 0 additions & 110 deletions jax/_src/third_party/numpy/linalg.py
@@ -32,13 +32,6 @@ def _assertNdSquareness(*arrays):
'Last 2 dimensions of the array must be square')


def _assert2d(*arrays):
for a in arrays:
if a.ndim != 2:
raise ValueError(f'{a.ndim}-dimensional array given. '
'Array must be two-dimensional')


@implements(np.linalg.cond)
def cond(x, p=None):
check_arraylike('jnp.linalg.cond', x)
@@ -60,106 +53,3 @@ def cond(x, p=None):
nan_mask = jnp.logical_and(jnp.isnan(r), ~jnp.isnan(x).any(axis=(-2, -1)))
r = jnp.where(orig_nan_check, jnp.where(nan_mask, jnp.inf, r), r)
return r


@implements(np.linalg.multi_dot)
def multi_dot(arrays, *, precision=None):
check_arraylike('jnp.linalg.multi_dot', *arrays)
n = len(arrays)
# optimization only makes sense for len(arrays) > 2
if n < 2:
raise ValueError("Expecting at least two arrays.")
elif n == 2:
return jnp.dot(arrays[0], arrays[1], precision=precision)

arrays = [jnp.asarray(a) for a in arrays]

# save original ndim to reshape the result array into the proper form later
ndim_first, ndim_last = arrays[0].ndim, arrays[-1].ndim
# Explicitly convert vectors to 2D arrays to keep the logic of the internal
# _multi_dot_* functions as simple as possible.
if arrays[0].ndim == 1:
arrays[0] = jnp.atleast_2d(arrays[0])
if arrays[-1].ndim == 1:
arrays[-1] = jnp.atleast_2d(arrays[-1]).T
_assert2d(*arrays)

# _multi_dot_three is much faster than _multi_dot_matrix_chain_order
if n == 3:
result = _multi_dot_three(*arrays, precision)
else:
order = _multi_dot_matrix_chain_order(arrays)
result = _multi_dot(arrays, order, 0, n - 1, precision)

# return proper shape
if ndim_first == 1 and ndim_last == 1:
return result[0, 0] # scalar
elif ndim_first == 1 or ndim_last == 1:
return result.ravel() # 1-D
else:
return result


def _multi_dot_three(A, B, C, precision):
"""
Find the best order for three arrays and do the multiplication.
For three arguments `_multi_dot_three` is approximately 15 times faster
than `_multi_dot_matrix_chain_order`
"""
a0, a1b0 = A.shape
b1c0, c1 = C.shape
# cost1 = cost((AB)C) = a0*a1b0*b1c0 + a0*b1c0*c1
cost1 = a0 * b1c0 * (a1b0 + c1)
# cost2 = cost(A(BC)) = a1b0*b1c0*c1 + a0*a1b0*c1
cost2 = a1b0 * c1 * (a0 + b1c0)

if cost1 < cost2:
return jnp.dot(jnp.dot(A, B, precision=precision), C, precision=precision)
else:
return jnp.dot(A, jnp.dot(B, C, precision=precision), precision=precision)
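
Editor's note: as a quick sanity check of the two cost expressions above (a worked example, not part of the original file), plugging in the shapes from the new docstring example — x: 200×5, y: 5×100, z: 100×10 — gives counts of scalar multiplications, i.e. half the flop estimates quoted earlier.

# Editorial worked example of the cost formulas in _multi_dot_three,
# using x: (200, 5), y: (5, 100), z: (100, 10) from the new docstring.
a0, a1b0 = 200, 5     # shape of the first matrix (x)
b1c0, c1 = 100, 10    # shape of the last matrix (z)
cost1 = a0 * b1c0 * (a1b0 + c1)   # cost((xy)z) = 200*100*15 = 300000
cost2 = a1b0 * c1 * (a0 + b1c0)   # cost(x(yz)) = 5*10*300   = 15000
print(cost1, cost2)               # 300000 15000 -> x @ (y @ z) is cheaper
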


def _multi_dot_matrix_chain_order(arrays, return_costs=False):
"""
Return a jnp.array that encodes the optimal order of multiplications.
The optimal order array is then used by `_multi_dot()` to do the
multiplication.
Also return the cost matrix if `return_costs` is `True`
The implementation CLOSELY follows Cormen, "Introduction to Algorithms",
Chapter 15.2, p. 370-378. Note that Cormen uses 1-based indices.
cost[i, j] = min([
cost[prefix] + cost[suffix] + cost_mult(prefix, suffix)
for k in range(i, j)])
"""
n = len(arrays)
# p stores the dimensions of the matrices
# Example for p: A_{10x100}, B_{100x5}, C_{5x50} --> p = [10, 100, 5, 50]
p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]]
# m is a matrix of costs of the subproblems
# m[i,j]: min number of scalar multiplications needed to compute A_{i..j}
m = np.zeros((n, n), dtype=np.double)
# s is the actual ordering
# s[i, j] is the value of k at which we split the product A_i..A_j
s = np.empty((n, n), dtype=np.intp)

for l in range(1, n):
for i in range(n - l):
j = i + l
m[i, j] = jnp.inf
for k in range(i, j):
q = m[i, k] + m[k+1, j] + p[i]*p[k+1]*p[j+1]
if q < m[i, j]:
m[i, j] = q
s[i, j] = k # Note that Cormen uses 1-based index

return (s, m) if return_costs else s


def _multi_dot(arrays, order, i, j, precision):
"""Actually do the multiplication with the given order."""
if i == j:
return arrays[i]
else:
return jnp.dot(_multi_dot(arrays, order, i, order[i, j], precision),
_multi_dot(arrays, order, order[i, j] + 1, j, precision),
precision=precision)
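
Editor's note: to make the removed chain-order logic concrete, here is a small self-contained sketch of the same Cormen-style dynamic program (an illustrative example, not part of the repository), run on the A_{10x100}, B_{100x5}, C_{5x50} case mentioned in the deleted comments.

# Editorial sketch of the matrix-chain-order DP that the deleted helper
# implemented; not part of this commit.
import numpy as np

def matrix_chain_order(p):
  # p holds the matrix dimensions: matrix i has shape (p[i], p[i+1]).
  n = len(p) - 1
  m = np.zeros((n, n))                 # m[i, j]: min scalar mults for A_i..A_j
  s = np.zeros((n, n), dtype=np.intp)  # s[i, j]: split index k for A_i..A_j
  for length in range(1, n):
    for i in range(n - length):
      j = i + length
      m[i, j] = np.inf
      for k in range(i, j):
        q = m[i, k] + m[k + 1, j] + p[i] * p[k + 1] * p[j + 1]
        if q < m[i, j]:
          m[i, j], s[i, j] = q, k
  return s, m

s, m = matrix_chain_order([10, 100, 5, 50])
print(m[0, -1])  # 7500.0 scalar multiplications for the optimal order
print(s[0, -1])  # 1 -> split after B, i.e. evaluate (A @ B) @ C
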
2 changes: 1 addition & 1 deletion jax/numpy/linalg.py
@@ -31,6 +31,7 @@
matrix_power as matrix_power,
matrix_rank as matrix_rank,
matrix_transpose as matrix_transpose,
multi_dot as multi_dot,
norm as norm,
outer as outer,
pinv as pinv,
@@ -47,5 +48,4 @@
)
from jax._src.third_party.numpy.linalg import (
cond as cond,
multi_dot as multi_dot,
)
