jnp.linalg: add matmul, tensordot, & svdvals
jakevdp committed Dec 19, 2023
1 parent c172be1 commit 0e9374b
Showing 6 changed files with 85 additions and 3 deletions.
3 changes: 3 additions & 0 deletions docs/jax.numpy.rst
@@ -456,6 +456,7 @@ jax.numpy.linalg
eigvalsh
inv
lstsq
matmul
matrix_norm
matrix_power
matrix_rank
@@ -468,6 +469,8 @@ jax.numpy.linalg
slogdet
solve
svd
svdvals
tensordot
tensorinv
tensorsolve
vector_norm
16 changes: 16 additions & 0 deletions jax/_src/numpy/linalg.py
@@ -14,6 +14,7 @@

from __future__ import annotations

from collections.abc import Sequence
from functools import partial

import numpy as np
@@ -748,3 +749,18 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
x2_arr = jax.numpy.moveaxis(x2_arr, axis, -1)
# TODO(jakevdp): call lax.dot_general directly
return jax.numpy.matmul(x1_arr[..., None, :], x2_arr[..., None])[..., 0, 0]


@_wraps(getattr(np.linalg, "matmul", None))
def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return jnp.matmul(x1, x2)


@_wraps(getattr(np.linalg, "tensordot", None))
def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
axes: int | tuple[Sequence[int], Sequence[int]] = 2) -> Array:
return jnp.tensordot(x1, x2, axes=axes)

@_wraps(getattr(np.linalg, "svdvals", None))
def svdvals(x: ArrayLike, /) -> Array:
return svd(x, compute_uv=False, hermitian=False)
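
The three functions added above are thin aliases of existing jax.numpy routines. A minimal usage sketch, not part of this commit (array shapes chosen purely for illustration), showing each new jnp.linalg entry agreeing with its existing counterpart:

import jax.numpy as jnp
import numpy as np

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)

# jnp.linalg.matmul forwards to jnp.matmul
np.testing.assert_allclose(jnp.linalg.matmul(a, b), jnp.matmul(a, b))

# jnp.linalg.tensordot forwards to jnp.tensordot
np.testing.assert_allclose(jnp.linalg.tensordot(a, b, axes=1),
                           jnp.tensordot(a, b, axes=1))

# jnp.linalg.svdvals returns only the singular values, i.e. the result of
# jnp.linalg.svd(..., compute_uv=False)
np.testing.assert_allclose(jnp.linalg.svdvals(a),
                           jnp.linalg.svd(a, compute_uv=False), rtol=1e-6)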
6 changes: 3 additions & 3 deletions jax/experimental/array_api/_linear_algebra_functions.py
@@ -85,7 +85,7 @@ def inv(x, /):

def matmul(x1, x2, /):
"""Computes the matrix product."""
-    return jax.numpy.matmul(x1, x2)
+    return jax.numpy.linalg.matmul(x1, x2)

def matrix_norm(x, /, *, keepdims=False, ord='fro'):
"""
@@ -160,11 +160,11 @@ def svdvals(x, /):
"""
Returns the singular values of a matrix (or a stack of matrices) x.
"""
-    return jax.numpy.linalg.svd(x, compute_uv=False)
+    return jax.numpy.linalg.svdvals(x)

def tensordot(x1, x2, /, *, axes=2):
"""Returns a tensor contraction of x1 and x2 over specific axes."""
-    return jax.numpy.tensordot(x1, x2, axes=axes)
+    return jax.numpy.linalg.tensordot(x1, x2, axes=axes)

def trace(x, /, *, offset=0, dtype=None):
"""
3 changes: 3 additions & 0 deletions jax/numpy/linalg.py
@@ -25,6 +25,7 @@
eigvalsh as eigvalsh,
inv as inv,
lstsq as lstsq,
matmul as matmul,
matrix_norm as matrix_norm,
matrix_power as matrix_power,
matrix_rank as matrix_rank,
@@ -36,6 +37,8 @@
slogdet as slogdet,
solve as solve,
svd as svd,
svdvals as svdvals,
tensordot as tensordot,
vector_norm as vector_norm,
vecdot as vecdot,
)
44 changes: 44 additions & 0 deletions tests/linalg_test.py
@@ -675,6 +675,50 @@ def np_fn(x, y, axis=axis):
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
self._CompileAndCheck(jnp_fn, args_maker)

# jnp.linalg.matmul is an alias of jnp.matmul; do a minimal test here.
@jtu.sample_product(
[
dict(lhs_shape=(3,), rhs_shape=(3,)), # vec-vec
dict(lhs_shape=(2, 3), rhs_shape=(3,)), # mat-vec
dict(lhs_shape=(3,), rhs_shape=(3, 4)), # vec-mat
dict(lhs_shape=(2, 3), rhs_shape=(3, 4)), # mat-mat
],
dtype=float_types + complex_types
)
@jax.default_matmul_precision("float32")
def testMatmul(self, lhs_shape, rhs_shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
np_fn = jtu.promote_like_jnp(
np.matmul if jtu.numpy_version() < (2, 0, 0) else np.linalg.matmul)
jnp_fn = jnp.linalg.matmul
tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12,
np.complex128: 1e-12}
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

# jnp.linalg.tensordot is an alias of jnp.tensordot; do a minimal test here.
@jtu.sample_product(
[
dict(lhs_shape=(2, 2, 2), rhs_shape=(2, 2), axes=0),
dict(lhs_shape=(2, 2, 2), rhs_shape=(2, 2), axes=1),
dict(lhs_shape=(2, 2, 2), rhs_shape=(2, 2), axes=2),
],
dtype=float_types + complex_types
)
@jax.default_matmul_precision("float32")
def testTensordot(self, lhs_shape, rhs_shape, axes, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
np_fn = jtu.promote_like_jnp(
partial(
np.tensordot if jtu.numpy_version() < (2, 0, 0) else np.linalg.tensordot,
axes=axes))
jnp_fn = partial(jnp.linalg.tensordot, axes=axes)
tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12,
np.complex128: 1e-12}
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

@jtu.sample_product(
[
16 changes: 16 additions & 0 deletions tests/svd_test.py
@@ -44,6 +44,22 @@
@jtu.with_config(jax_numpy_rank_promotion='allow')
class SvdTest(jtu.JaxTestCase):

@jtu.sample_product(
shape=[(4, 5), (3, 4, 5), (2, 3, 4, 5)],
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
)
@jax.default_matmul_precision('float32')
def testSvdvals(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
jnp_fun = jax.numpy.linalg.svdvals
if jtu.numpy_version() < (2, 0, 0):
np_fun = lambda x: np.linalg.svd(x, compute_uv=False)
else:
np_fun = np.linalg.svdvals
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=_SVD_RTOL)
self._CompileAndCheck(jnp_fun, args_maker, rtol=_SVD_RTOL)

@jtu.sample_product(
[dict(m=m, n=n) for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])],
log_cond=np.linspace(1, _MAX_LOG_CONDITION_NUM, 4),
