Skip to content

Commit

Permalink
Fix the pinv function, which was implicitly using __array__
Browse files Browse the repository at this point in the history
  • Loading branch information
asmeurer committed Oct 31, 2024
1 parent bb28167 commit d630ee5
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions array_api_strict/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
# default tolerance by max(M, N).
if rtol is None:
rtol = max(x.shape[-2:]) * finfo(x.dtype).eps
if isinstance(rtol, Array):
rtol = rtol._array
return Array._new(np.linalg.pinv(x._array, rcond=rtol), device=x.device)

@requires_extension('linalg')
Expand Down

0 comments on commit d630ee5

Please sign in to comment.