Skip to content

Commit

Permalink
Add det/slogdet workaround for dask
Browse files Browse the repository at this point in the history
  • Loading branch information
lithomas1 committed Apr 23, 2024
1 parent 376038e commit 980129b
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 4 deletions.
72 changes: 72 additions & 0 deletions array_api_compat/dask/array/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,78 @@ def svdvals(x: Array) -> Array:
vector_norm = get_xp(da)(_linalg.vector_norm)
diagonal = get_xp(da)(_linalg.diagonal)

# Calculate determinant via PLU decomp
def det(x: Array) -> Array:
import scipy.linalg

# L has det 1 so don't need to worry about it
p, _, u = da.linalg.lu(x)

# TODO: numerical stability?
u_det = da.prod(da.diag(u))

# Now, time to calculate determinant of p

# (from reading the source code)
# We know that dask lu decomp forces square chunks
# We also know that the P matrix will only be non-zero
# for a block i, j if and only if i = j

# So we will calculate the determinant of each block on
# the diagonal (of blocks)

# This isn't ideal, but hopefully still lets out of core work
# properly since each block should be able to fit in memory

blocks_shape = p.blocks.shape
n_row_blocks = blocks_shape[0]

p_det = 1
for i in range(n_row_blocks):
p_det *= scipy.linalg.det(p.blocks[i, i].compute())
return p_det * u_det

SlogdetResult = _linalg.SlogdetResult

# Calculate determinant via PLU decomp
def slogdet(x: Array) -> Array:
import scipy.linalg

# L has det 1 so don't need to worry about it
p, _, u = da.linalg.lu(x)

u_diag = da.diag(u)
neg_cnt = (u_diag < 0).sum()

u_logabsdet = da.sum(da.log(da.abs(u_diag)))

# Now, time to calculate determinant of p

# (from reading the source code)
# We know that dask lu decomp forces square chunks
# We also know that the P matrix will only be non-zero
# for a block i, j if and only if i = j

# So we will calculate the determinant of each block on
# the diagonal (of blocks)

# This isn't ideal, but hopefully still lets out of core work
# properly since each block should be able to fit in memory

blocks_shape = p.blocks.shape
n_row_blocks = blocks_shape[0]

sign = 1
for i in range(n_row_blocks):
sign *= scipy.linalg.det(p.blocks[i, i].compute())

if neg_cnt % 2 != 0:
sign *= -1
return SlogdetResult(sign, u_logabsdet)




__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot",
"matrix_transpose", "vecdot", "EighResult",
"QRResult", "SlogdetResult", "SVDResult", "qr",
Expand Down
7 changes: 3 additions & 4 deletions dask-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ array_api_tests/test_linalg.py::test_cholesky
array_api_tests/test_linalg.py::test_tensordot
# probably same reason for failing as numpy
array_api_tests/test_linalg.py::test_trace
# our version depends on dask's LU, which doesn't support ndim > 2
array_api_tests/test_linalg.py::test_det
array_api_tests/test_linalg.py::test_slogdet

# AssertionError: out.dtype=uint64, but should be uint8 [tensordot(uint8, uint8)]
array_api_tests/test_linalg.py::test_linalg_tensordot
Expand All @@ -97,18 +100,14 @@ array_api_tests/test_linalg.py::test_linalg_matmul

# Linalg - these don't exist in dask
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cross]
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.det]
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigh]
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigvalsh]
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_power]
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.pinv]
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.slogdet]
array_api_tests/test_linalg.py::test_cross
array_api_tests/test_linalg.py::test_det
array_api_tests/test_linalg.py::test_eigh
array_api_tests/test_linalg.py::test_eigvalsh
array_api_tests/test_linalg.py::test_pinv
array_api_tests/test_linalg.py::test_slogdet
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
array_api_tests/test_has_names.py::test_has_names[linalg-det]
array_api_tests/test_has_names.py::test_has_names[linalg-eigh]
Expand Down

0 comments on commit 980129b

Please sign in to comment.