diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 026f6aa3e332..adb8131a49d7 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -466,6 +466,7 @@ jax.numpy.linalg cond cross det + diagonal eig eigh eigvals diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 5e50b7ef13b5..617ecfd314fc 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -746,14 +746,24 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array: @_wraps(getattr(np.linalg, "matmul", None)) def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array: + check_arraylike('jnp.linalg.matmul', x1, x2) return jnp.matmul(x1, x2) @_wraps(getattr(np.linalg, "tensordot", None)) def tensordot(x1: ArrayLike, x2: ArrayLike, /, *, axes: int | tuple[Sequence[int], Sequence[int]] = 2) -> Array: + check_arraylike('jnp.linalg.tensordot', x1, x2) return jnp.tensordot(x1, x2, axes=axes) + @_wraps(getattr(np.linalg, "svdvals", None)) def svdvals(x: ArrayLike, /) -> Array: + check_arraylike('jnp.linalg.svdvals', x) return svd(x, compute_uv=False, hermitian=False) + + +@_wraps(getattr(np.linalg, "diagonal", None)) +def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array: + check_arraylike('jnp.linalg.diagonal', x) + return jnp.diagonal(x, offset=offset, axis1=-2, axis2=-1) diff --git a/jax/experimental/array_api/_linear_algebra_functions.py b/jax/experimental/array_api/_linear_algebra_functions.py index 4703c3dfb997..c66b0f246927 100644 --- a/jax/experimental/array_api/_linear_algebra_functions.py +++ b/jax/experimental/array_api/_linear_algebra_functions.py @@ -59,10 +59,7 @@ def diagonal(x, /, *, offset=0): """ Returns the specified diagonals of a matrix (or a stack of matrices) x. """ - f = partial(jax.numpy.diagonal, offset=offset) - for _ in range(x.ndim - 2): - f = jax.vmap(f) - return f(x) + return jax.numpy.linalg.diagonal(x, offset=offset) def eigh(x, /): """ diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index 5d8f716da7b4..a4e65fc32fb9 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -19,6 +19,7 @@ cholesky as cholesky, cross as cross, det as det, + diagonal as diagonal, eig as eig, eigh as eigh, eigvals as eigvals, diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 92c1d6b0288c..ec799906ee7d 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1251,6 +1251,22 @@ def testOuter(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype): self._CheckAgainstNumpy(np_fun, lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker) + @jtu.sample_product( + shape = [(2, 3), (3, 2), (3, 3, 4), (4, 3, 3), (2, 3, 4, 5)], + dtype = jtu.dtypes.all, + offset=range(-2, 3) + ) + def testDiagonal(self, shape, dtype, offset): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + lax_fun = partial(jnp.linalg.diagonal, offset=offset) + if jtu.numpy_version() >= (2, 0, 0): + np_fun = partial(np.linalg.diagonal, offset=offset) + else: + np_fun = partial(np.diagonal, offset=offset, axis1=-2, axis2=-1) + self._CheckAgainstNumpy(np_fun, lax_fun, args_maker) + self._CompileAndCheck(lax_fun, args_maker) + class ScipyLinalgTest(jtu.JaxTestCase):