Skip to content

Commit

Permalink
[array API] add jnp.linalg.diagonal
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 11, 2024
1 parent 35fc2ed commit b08a010
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ jax.numpy.linalg
cond
cross
det
diagonal
eig
eigh
eigvals
Expand Down
10 changes: 10 additions & 0 deletions jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,14 +746,24 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:

@_wraps(getattr(np.linalg, "matmul", None))
def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike('jnp.linalg.matmul', x1, x2)
return jnp.matmul(x1, x2)


@_wraps(getattr(np.linalg, "tensordot", None))
def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
axes: int | tuple[Sequence[int], Sequence[int]] = 2) -> Array:
check_arraylike('jnp.linalg.tensordot', x1, x2)
return jnp.tensordot(x1, x2, axes=axes)


@_wraps(getattr(np.linalg, "svdvals", None))
def svdvals(x: ArrayLike, /) -> Array:
check_arraylike('jnp.linalg.svdvals', x)
return svd(x, compute_uv=False, hermitian=False)


@_wraps(getattr(np.linalg, "diagonal", None))
def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array:
check_arraylike('jnp.linalg.diagonal', x)
return jnp.diagonal(x, offset=offset, axis1=-2, axis2=-1)
5 changes: 1 addition & 4 deletions jax/experimental/array_api/_linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,7 @@ def diagonal(x, /, *, offset=0):
"""
Returns the specified diagonals of a matrix (or a stack of matrices) x.
"""
f = partial(jax.numpy.diagonal, offset=offset)
for _ in range(x.ndim - 2):
f = jax.vmap(f)
return f(x)
return jax.numpy.linalg.diagonal(x, offset=offset)

def eigh(x, /):
"""
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
cholesky as cholesky,
cross as cross,
det as det,
diagonal as diagonal,
eig as eig,
eigh as eigh,
eigvals as eigvals,
Expand Down
16 changes: 16 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,6 +1251,22 @@ def testOuter(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype):
self._CheckAgainstNumpy(np_fun, lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker)

@jtu.sample_product(
shape = [(2, 3), (3, 2), (3, 3, 4), (4, 3, 3), (2, 3, 4, 5)],
dtype = jtu.dtypes.all,
offset=range(-2, 3)
)
def testDiagonal(self, shape, dtype, offset):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
lax_fun = partial(jnp.linalg.diagonal, offset=offset)
if jtu.numpy_version() >= (2, 0, 0):
np_fun = partial(np.linalg.diagonal, offset=offset)
else:
np_fun = partial(np.diagonal, offset=offset, axis1=-2, axis2=-1)
self._CheckAgainstNumpy(np_fun, lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker)


class ScipyLinalgTest(jtu.JaxTestCase):

Expand Down

0 comments on commit b08a010

Please sign in to comment.