Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement numpy.linalg.multi_dot #2726

Merged
merged 2 commits into from
Apr 16, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ jax.numpy.linalg
inv
matrix_power
matrix_rank
multi_dot
norm
pinv
qr
Expand Down
2 changes: 1 addition & 1 deletion jax/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .vectorize import vectorize
from . import lax_numpy as np
from ..util import get_module_functions
from ..third_party.numpy.linalg import cond, tensorinv, tensorsolve
from ..third_party.numpy.linalg import cond, multi_dot, tensorinv, tensorsolve

_T = lambda x: np.swapaxes(x, -1, -2)

Expand Down
108 changes: 108 additions & 0 deletions jax/third_party/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ def _assertNdSquareness(*arrays):
'Last 2 dimensions of the array must be square')


def _assert2d(*arrays):
for a in arrays:
if a.ndim != 2:
raise ValueError(f'{a.ndim}-dimensional array given. '
'Array must be two-dimensional')


@_wraps(onp.linalg.cond)
def cond(x, p=None):
_assertNoEmpty2d(x)
Expand Down Expand Up @@ -97,3 +104,104 @@ def tensorsolve(a, b, axes=None):
res = res.reshape(Q)

return res


@_wraps(onp.linalg.multi_dot)
def multi_dot(arrays):
  # Multiply a sequence of arrays together, picking the order of the pairwise
  # products from the operand shapes. The first and last operands may be 1-D;
  # every other operand must be 2-D (enforced by _assert2d below).
  count = len(arrays)
  if count < 2:
    raise ValueError("Expecting at least two arrays.")
  if count == 2:
    # With only two operands there is no ordering to choose.
    return np.dot(arrays[0], arrays[1])

  arrays = list(map(np.asarray, arrays))

  # Remember whether the end operands were vectors so that the result can be
  # given the matching shape at the end.
  first_is_vector = arrays[0].ndim == 1
  last_is_vector = arrays[-1].ndim == 1
  # Promote vectors to 2-D (row vector on the left, column vector on the
  # right) so the helper functions only ever deal with matrices.
  if first_is_vector:
    arrays[0] = np.atleast_2d(arrays[0])
  if last_is_vector:
    arrays[-1] = np.atleast_2d(arrays[-1]).T
  _assert2d(*arrays)

  if count == 3:
    # Three operands have only two possible groupings; compare them directly
    # instead of building the full dynamic-programming table.
    product = _multi_dot_three(*arrays)
  else:
    split_table = _multi_dot_matrix_chain_order(arrays)
    product = _multi_dot(arrays, split_table, 0, count - 1)

  # Undo the vector promotion on the way out.
  if first_is_vector and last_is_vector:
    return product[0, 0]  # scalar
  if first_is_vector or last_is_vector:
    return product.ravel()  # 1-D
  return product


def _multi_dot_three(A, B, C):
"""
Find the best order for three arrays and do the multiplication.
For three arguments `_multi_dot_three` is approximately 15 times faster
than `_multi_dot_matrix_chain_order`
"""
a0, a1b0 = A.shape
b1c0, c1 = C.shape
# cost1 = cost((AB)C) = a0*a1b0*b1c0 + a0*b1c0*c1
cost1 = a0 * b1c0 * (a1b0 + c1)
# cost2 = cost(A(BC)) = a1b0*b1c0*c1 + a0*a1b0*c1
cost2 = a1b0 * c1 * (a0 + b1c0)

if cost1 < cost2:
return np.dot(np.dot(A, B), C)
else:
return np.dot(A, np.dot(B, C))


def _multi_dot_matrix_chain_order(arrays, return_costs=False):
"""
Return a np.array that encodes the optimal order of mutiplications.
The optimal order array is then used by `_multi_dot()` to do the
multiplication.
Also return the cost matrix if `return_costs` is `True`
The implementation CLOSELY follows Cormen, "Introduction to Algorithms",
Chapter 15.2, p. 370-378. Note that Cormen uses 1-based indices.
cost[i, j] = min([
cost[prefix] + cost[suffix] + cost_mult(prefix, suffix)
for k in range(i, j)])
"""
n = len(arrays)
# p stores the dimensions of the matrices
# Example for p: A_{10x100}, B_{100x5}, C_{5x50} --> p = [10, 100, 5, 50]
p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]]
# m is a matrix of costs of the subproblems
# m[i,j]: min number of scalar multiplications needed to compute A_{i..j}
m = onp.zeros((n, n), dtype=onp.double)
# s is the actual ordering
# s[i, j] is the value of k at which we split the product A_i..A_j
s = onp.empty((n, n), dtype=onp.intp)

for l in range(1, n):
for i in range(n - l):
j = i + l
m[i, j] = np.inf
for k in range(i, j):
q = m[i, k] + m[k+1, j] + p[i]*p[k+1]*p[j+1]
if q < m[i, j]:
m[i, j] = q
s[i, j] = k # Note that Cormen uses 1-based index

return (s, m) if return_costs else s


def _multi_dot(arrays, order, i, j):
"""Actually do the multiplication with the given order."""
if i == j:
return arrays[i]
else:
return np.dot(_multi_dot(arrays, order, i, order[i, j]),
_multi_dot(arrays, order, order[i, j] + 1, j))
28 changes: 28 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,34 @@ def testMatrixRank(self, shape, dtype, rng_factory):
self._CompileAndCheck(np.linalg.matrix_rank, args_maker,
check_dtypes=False, rtol=1e-3)

@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": "_shapes={}".format(
    ','.join(jtu.format_shape_dtype_string(s, dtype) for s in shapes)),
   "shapes": shapes, "dtype": dtype, "rng_factory": rng_factory}
  for shapes in [
    [(3, ), (3, 1)],  # quick-out codepath
    [(1, 3), (3, 5), (5, 2)],  # multi_dot_three codepath
    [(1, 3), (3, 5), (5, 2), (2, 7), (7, )]  # dynamic programming codepath
  ]
  for dtype in float_types + complex_types
  for rng_factory in [jtu.rand_default]))
def testMultiDot(self, shapes, dtype, rng_factory):
  # Compare np.linalg.multi_dot with onp.linalg.multi_dot, and its compiled
  # form with its uncompiled form, on one operand list per code path.
  rng = rng_factory()
  _skip_if_unsupported_type(dtype)

  def args_maker():
    # multi_dot takes a single argument: the list of operands.
    return [[rng(shape, dtype) for shape in shapes]]

  # Looser tolerances are used when running on TPU.
  if jtu.device_under_test() == "tpu":
    single_tol, double_tol = 1e-2, 1e-4
  else:
    single_tol, double_tol = 1e-4, 1e-10
  tol = {onp.float32: single_tol, onp.float64: double_tol,
         onp.complex64: single_tol, onp.complex128: double_tol}

  self._CheckAgainstNumpy(onp.linalg.multi_dot, np.linalg.multi_dot,
                          args_maker, check_dtypes=True,
                          tol=tol)
  self._CompileAndCheck(np.linalg.multi_dot, args_maker, check_dtypes=True,
                        atol=tol, rtol=tol)

# Regression test for incorrect type for eigenvalues of a complex matrix.
@jtu.skip_on_devices("tpu") # TODO(phawkins): No complex eigh implementation on TPU.
def testIssue669(self):
Expand Down