Skip to content

Commit

Permalink
Merge pull request #18833 from jakevdp:linalg-shapes
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 588458694
  • Loading branch information
jax authors committed Dec 6, 2023
2 parents 4bdcb11 + 45905fa commit 9e3a8fa
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 0 deletions.
2 changes: 2 additions & 0 deletions jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,8 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, tuple[Array, Array]]
def solve(a: ArrayLike, b: ArrayLike) -> Array:
check_arraylike("jnp.linalg.solve", a, b)
a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
if a.ndim >= 2 and b.ndim > a.ndim:
a = lax.expand_dims(a, tuple(range(b.ndim - a.ndim)))
return lax_linalg._solve(a, b)


Expand Down
1 change: 1 addition & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,7 @@ def tensor_maker():
((4, 4), (4,)),
((8, 8), (8, 4)),
((1, 2, 2), (3, 2)),
((2, 2), (3, 2, 2)),
((2, 1, 3, 3), (1, 4, 3, 4)),
((1, 0, 0), (1, 0, 2)),
]
Expand Down

0 comments on commit 9e3a8fa

Please sign in to comment.