Implement numpy.linalg.multi_dot (jax-ml#2726)
* Implement numpy.linalg.multi_dot

* Thread precision through multi_dot
jakevdp authored and jacobjinkelly committed Apr 21, 2020
1 parent 279336e commit 3b69c41
Showing 4 changed files with 137 additions and 1 deletion.
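As a usage sketch of the new function (illustrative only; the names and the `precision` keyword follow the diff below, which threads precision through each underlying dot):

import jax.numpy as jnp
from jax import lax

A = jnp.ones((10, 100))
B = jnp.ones((100, 5))
C = jnp.ones((5, 50))

# multi_dot picks the cheapest association order for A @ B @ C;
# `precision` is forwarded to every underlying dot.
out = jnp.linalg.multi_dot([A, B, C], precision=lax.Precision.HIGHEST)
print(out.shape)  # (10, 50)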
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
@@ -305,6 +305,7 @@ jax.numpy.linalg
inv
matrix_power
matrix_rank
multi_dot
norm
pinv
qr
2 changes: 1 addition & 1 deletion jax/numpy/linalg.py
@@ -30,7 +30,7 @@
from .vectorize import vectorize
from . import lax_numpy as np
from ..util import get_module_functions
-from ..third_party.numpy.linalg import cond, tensorinv, tensorsolve
+from ..third_party.numpy.linalg import cond, multi_dot, tensorinv, tensorsolve

_T = lambda x: np.swapaxes(x, -1, -2)

109 changes: 109 additions & 0 deletions jax/third_party/numpy/linalg.py
@@ -33,6 +33,13 @@ def _assertNdSquareness(*arrays):
                'Last 2 dimensions of the array must be square')


def _assert2d(*arrays):
    for a in arrays:
        if a.ndim != 2:
            raise ValueError(f'{a.ndim}-dimensional array given. '
                             'Array must be two-dimensional')


@_wraps(onp.linalg.cond)
def cond(x, p=None):
    _assertNoEmpty2d(x)
@@ -97,3 +104,105 @@ def tensorsolve(a, b, axes=None):
    res = res.reshape(Q)

    return res


@_wraps(onp.linalg.multi_dot)
def multi_dot(arrays, *, precision=None):
    n = len(arrays)
    # optimization only makes sense for len(arrays) > 2
    if n < 2:
        raise ValueError("Expecting at least two arrays.")
    elif n == 2:
        return np.dot(arrays[0], arrays[1], precision=precision)

    arrays = [np.asarray(a) for a in arrays]

    # save original ndim to reshape the result array into the proper form later
    ndim_first, ndim_last = arrays[0].ndim, arrays[-1].ndim
    # Explicitly convert vectors to 2D arrays to keep the logic of the internal
    # _multi_dot_* functions as simple as possible.
    if arrays[0].ndim == 1:
        arrays[0] = np.atleast_2d(arrays[0])
    if arrays[-1].ndim == 1:
        arrays[-1] = np.atleast_2d(arrays[-1]).T
    _assert2d(*arrays)

    # _multi_dot_three is much faster than _multi_dot_matrix_chain_order
    if n == 3:
        result = _multi_dot_three(*arrays, precision)
    else:
        order = _multi_dot_matrix_chain_order(arrays)
        result = _multi_dot(arrays, order, 0, n - 1, precision)

    # return proper shape
    if ndim_first == 1 and ndim_last == 1:
        return result[0, 0]  # scalar
    elif ndim_first == 1 or ndim_last == 1:
        return result.ravel()  # 1-D
    else:
        return result


def _multi_dot_three(A, B, C, precision):
    """
    Find the best order for three arrays and do the multiplication.

    For three arguments `_multi_dot_three` is approximately 15 times faster
    than `_multi_dot_matrix_chain_order`.
    """
    a0, a1b0 = A.shape
    b1c0, c1 = C.shape
    # cost1 = cost((AB)C) = a0*a1b0*b1c0 + a0*b1c0*c1
    cost1 = a0 * b1c0 * (a1b0 + c1)
    # cost2 = cost(A(BC)) = a1b0*b1c0*c1 + a0*a1b0*c1
    cost2 = a1b0 * c1 * (a0 + b1c0)

    if cost1 < cost2:
        return np.dot(np.dot(A, B, precision=precision), C, precision=precision)
    else:
        return np.dot(A, np.dot(B, C, precision=precision), precision=precision)
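# Illustrative aside (not part of the original commit): for shapes
# A: 10x100, B: 100x5, C: 5x50, the two costs above evaluate to
#     cost1 = 10 * 5 * (100 + 50)  = 7500     # (AB)C
#     cost2 = 100 * 50 * (10 + 5)  = 75000    # A(BC)
# so the cheaper (AB)C association is selected.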


def _multi_dot_matrix_chain_order(arrays, return_costs=False):
    """
    Return a np.array that encodes the optimal order of multiplications.

    The optimal order array is then used by `_multi_dot()` to do the
    multiplication.

    Also return the cost matrix if `return_costs` is `True`.

    The implementation CLOSELY follows Cormen, "Introduction to Algorithms",
    Chapter 15.2, p. 370-378.  Note that Cormen uses 1-based indices.

        cost[i, j] = min([
            cost[prefix] + cost[suffix] + cost_mult(prefix, suffix)
            for k in range(i, j)])
    """
    n = len(arrays)
    # p stores the dimensions of the matrices
    # Example for p: A_{10x100}, B_{100x5}, C_{5x50} --> p = [10, 100, 5, 50]
    p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]]
    # m is a matrix of costs of the subproblems
    # m[i, j]: min number of scalar multiplications needed to compute A_{i..j}
    m = onp.zeros((n, n), dtype=onp.double)
    # s is the actual ordering
    # s[i, j] is the value of k at which we split the product A_i..A_j
    s = onp.empty((n, n), dtype=onp.intp)

    for l in range(1, n):
        for i in range(n - l):
            j = i + l
            m[i, j] = np.inf
            for k in range(i, j):
                q = m[i, k] + m[k+1, j] + p[i]*p[k+1]*p[j+1]
                if q < m[i, j]:
                    m[i, j] = q
                    s[i, j] = k  # Note that Cormen uses 1-based index

    return (s, m) if return_costs else s


def _multi_dot(arrays, order, i, j, precision):
    """Actually do the multiplication with the given order."""
    if i == j:
        return arrays[i]
    else:
        return np.dot(_multi_dot(arrays, order, i, order[i, j], precision),
                      _multi_dot(arrays, order, order[i, j] + 1, j, precision),
                      precision=precision)
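For intuition about what the chain-order machinery buys, here is a small self-contained check written against plain NumPy's reference `multi_dot` (an illustration, not part of the commit):

import numpy as onp

A = onp.ones((10, 100))
B = onp.ones((100, 5))
C = onp.ones((5, 50))

# With p = [10, 100, 5, 50], (AB)C costs 7500 scalar multiplications
# versus 75000 for A(BC); either association yields the same values.
result = onp.linalg.multi_dot([A, B, C])
assert onp.allclose(result, (A @ B) @ C)
assert result.shape == (10, 50)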
26 changes: 26 additions & 0 deletions tests/linalg_test.py
@@ -757,6 +757,32 @@ def testMatrixRank(self, shape, dtype, rng_factory):
    self._CompileAndCheck(np.linalg.matrix_rank, args_maker,
                          check_dtypes=False, rtol=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shapes={}".format(
          ','.join(jtu.format_shape_dtype_string(s, dtype) for s in shapes)),
       "shapes": shapes, "dtype": dtype, "rng_factory": rng_factory}
      for shapes in [
          [(3,), (3, 1)],  # quick-out codepath
          [(1, 3), (3, 5), (5, 2)],  # multi_dot_three codepath
          [(1, 3), (3, 5), (5, 2), (2, 7), (7,)]  # dynamic programming codepath
      ]
      for dtype in float_types + complex_types
      for rng_factory in [jtu.rand_default]))
  def testMultiDot(self, shapes, dtype, rng_factory):
    rng = rng_factory()
    _skip_if_unsupported_type(dtype)
    args_maker = lambda: [[rng(shape, dtype) for shape in shapes]]

    onp_fun = onp.linalg.multi_dot
    jnp_fun = partial(np.linalg.multi_dot, precision=lax.Precision.HIGHEST)
    tol = {onp.float32: 1e-4, onp.float64: 1e-10,
           onp.complex64: 1e-4, onp.complex128: 1e-10}

    self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True,
                            tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True,
                          atol=tol, rtol=tol)

  # Regression test for incorrect type for eigenvalues of a complex matrix.
  @jtu.skip_on_devices("tpu")  # TODO(phawkins): No complex eigh implementation on TPU.
  def testIssue669(self):
