Skip to content

Commit

Permalink
Allow no-copy construction from SciPy COO arrays. (#822)
Browse files Browse the repository at this point in the history
* Hold reference to converted scipy.sparse.coo_*.

* Allow comparison against NumPy dtypes.
  • Loading branch information
hameerabbasi authored Dec 3, 2024
1 parent 13dda8e commit 128a567
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
15 changes: 8 additions & 7 deletions sparse/mlir_backend/_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array:

return from_constituent_arrays(format=csx_format, arrays=(indptr, indices, data), shape=arr.shape)
case "coo":
if copy is not None and not copy:
raise RuntimeError(f"`scipy.sparse.{type(arr.__name__)}` cannot be zero-copy converted.")
from ._common import _hold_ref

row, col = arr.row, arr.col
if row.dtype != col.dtype:
raise RuntimeError(f"`row` and `col` dtypes must be the same: {row.dtype} != {col.dtype}.")
Expand All @@ -89,10 +89,8 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array:
data = arr.data
if copy:
data = data.copy()

# TODO: Make them own the data until https://github.com/llvm/llvm-project/issues/116012 is fixed.
row = row.copy()
col = col.copy()
row = row.copy()
col = col.copy()

coo_format = (
Coo()
Expand All @@ -103,7 +101,10 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array:
.build()
)

return from_constituent_arrays(format=coo_format, arrays=(pos, row, col, data), shape=arr.shape)
ret = from_constituent_arrays(format=coo_format, arrays=(pos, row, col, data), shape=arr.shape)
if not copy:
_hold_ref(ret, arr)
return ret
case _:
raise NotImplementedError(f"No conversion implemented for `scipy.sparse.{type(arr.__name__)}`.")

Expand Down
5 changes: 5 additions & 0 deletions sparse/mlir_backend/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def np_dtype(self) -> np.dtype:
def to_ctype(self):
return rt.as_ctype(self.np_dtype)

def __eq__(self, value):
if np.isdtype(value) or isinstance(value, str):
value = asdtype(value)
return super().__eq__(value)


@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
class IeeeRealFloatingDType(DType):
Expand Down
5 changes: 2 additions & 3 deletions sparse/mlir_backend/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def test_coo_3d_format(dtype):
@parametrize_dtypes
def test_sparse_vector_format(dtype):
if sparse.asdtype(dtype) in {sparse.complex64, sparse.complex128}:
pytest.xfail("Heisenbug")
pytest.xfail("The sparse_vector format returns incorrect results for complex dtypes.")
format = sparse.formats.Coo().with_ndim(1).with_dtype(dtype).build()

SHAPE = (10,)
Expand Down Expand Up @@ -465,8 +465,7 @@ def test_asformat(rng, src_fmt, dst_fmt):

expected = sps_arr.asformat(dst_fmt)

copy = None if dst_fmt == "coo" else False
actual_fmt = sparse.asarray(expected, copy=copy).format
actual_fmt = sparse.asarray(expected, copy=False).format
actual = sp_arr.asformat(actual_fmt)
actual_sps = sparse.to_scipy(actual)

Expand Down

0 comments on commit 128a567

Please sign in to comment.