From 128a5676979ece95e02e848c03eb6939e160c190 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 3 Dec 2024 11:57:44 +0100 Subject: [PATCH] Allow no-copy construction from SciPy COO arrays. (#822) * Hold reference to converted scipy.sparse.coo_*. * Allow comparison against NumPy dtypes. --- sparse/mlir_backend/_conversions.py | 15 ++++++++------- sparse/mlir_backend/_dtypes.py | 5 +++++ sparse/mlir_backend/tests/test_simple.py | 5 ++--- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/sparse/mlir_backend/_conversions.py b/sparse/mlir_backend/_conversions.py index 28e07f95..e66eb408 100644 --- a/sparse/mlir_backend/_conversions.py +++ b/sparse/mlir_backend/_conversions.py @@ -78,8 +78,8 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array: return from_constituent_arrays(format=csx_format, arrays=(indptr, indices, data), shape=arr.shape) case "coo": - if copy is not None and not copy: - raise RuntimeError(f"`scipy.sparse.{type(arr.__name__)}` cannot be zero-copy converted.") + from ._common import _hold_ref + row, col = arr.row, arr.col if row.dtype != col.dtype: raise RuntimeError(f"`row` and `col` dtypes must be the same: {row.dtype} != {col.dtype}.") @@ -89,10 +89,8 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array: data = arr.data if copy: data = data.copy() - - # TODO: Make them own the data until https://github.com/llvm/llvm-project/issues/116012 is fixed. - row = row.copy() - col = col.copy() + row = row.copy() + col = col.copy() coo_format = ( Coo() @@ -103,7 +101,10 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array: .build() ) - return from_constituent_arrays(format=coo_format, arrays=(pos, row, col, data), shape=arr.shape) + ret = from_constituent_arrays(format=coo_format, arrays=(pos, row, col, data), shape=arr.shape) + if not copy: + _hold_ref(ret, arr) + return ret case _: raise NotImplementedError(f"No conversion implemented for `scipy.sparse.{type(arr.__name__)}`.") diff --git a/sparse/mlir_backend/_dtypes.py b/sparse/mlir_backend/_dtypes.py index 7dad1438..d1613b80 100644 --- a/sparse/mlir_backend/_dtypes.py +++ b/sparse/mlir_backend/_dtypes.py @@ -33,6 +33,11 @@ def np_dtype(self) -> np.dtype: def to_ctype(self): return rt.as_ctype(self.np_dtype) + def __eq__(self, value): + if np.isdtype(value) or isinstance(value, str): + value = asdtype(value) + return super().__eq__(value) + @dataclasses.dataclass(eq=True, frozen=True, kw_only=True) class IeeeRealFloatingDType(DType): diff --git a/sparse/mlir_backend/tests/test_simple.py b/sparse/mlir_backend/tests/test_simple.py index 61271c7c..d0d75510 100644 --- a/sparse/mlir_backend/tests/test_simple.py +++ b/sparse/mlir_backend/tests/test_simple.py @@ -301,7 +301,7 @@ def test_coo_3d_format(dtype): @parametrize_dtypes def test_sparse_vector_format(dtype): if sparse.asdtype(dtype) in {sparse.complex64, sparse.complex128}: - pytest.xfail("Heisenbug") + pytest.xfail("The sparse_vector format returns incorrect results for complex dtypes.") format = sparse.formats.Coo().with_ndim(1).with_dtype(dtype).build() SHAPE = (10,) @@ -465,8 +465,7 @@ def test_asformat(rng, src_fmt, dst_fmt): expected = sps_arr.asformat(dst_fmt) - copy = None if dst_fmt == "coo" else False - actual_fmt = sparse.asarray(expected, copy=copy).format + actual_fmt = sparse.asarray(expected, copy=False).format actual = sp_arr.asformat(actual_fmt) actual_sps = sparse.to_scipy(actual)