Skip to content

Commit

Permalink
Simplify _from_scipy function.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Oct 22, 2024
1 parent 0027e00 commit a81c8d4
Showing 1 changed file with 16 additions and 30 deletions.
46 changes: 16 additions & 30 deletions sparse/mlir_backend/_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,32 +88,26 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array:
owns_memory=False,
)

indptr_np = arr.indptr
indices_np = arr.indices
data_np = arr.data
indptr = arr.indptr
indices = arr.indices
data = arr.data

if copy:
indptr_np = indptr_np.copy()
indices_np = indices_np.copy()
data_np = data_np.copy()

indptr = numpy_to_ranked_memref(indptr_np)
indices = numpy_to_ranked_memref(indices_np)
data = numpy_to_ranked_memref(data_np)

storage = csr_format._get_ctypes_type()(indptr, indices, data)
_hold_ref(storage, indptr_np)
_hold_ref(storage, indices_np)
_hold_ref(storage, data_np)
return Array(storage=storage, shape=arr.shape)
indptr = indptr.copy()
indices = indices.copy()
data = data.copy()

return from_constituent_arrays(format=csr_format, arrays=(indptr, indices, data), shape=arr.shape)
case "coo":
if copy is not None and not copy:
raise RuntimeError(f"`scipy.sparse.{type(arr.__name__)}` cannot be zero-copy converted.")
coords_np = np.stack([arr.row, arr.col], axis=1)
pos_np = np.array([0, arr.nnz], dtype=np.int64)
pos_width = pos_np.dtype.itemsize * 8
crd_width = coords_np.dtype.itemsize * 8
data_np = arr.data.copy()
coords = np.stack([arr.row, arr.col], axis=1)
pos = np.array([0, arr.nnz], dtype=np.int64)
pos_width = pos.dtype.itemsize * 8
crd_width = coords.dtype.itemsize * 8
data = arr.data
if copy:
data = arr.data.copy()

level_props = LevelProperties(0)
if not arr.has_canonical_format:
Expand All @@ -131,15 +125,7 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array:
owns_memory=False,
)

pos = numpy_to_ranked_memref(pos_np)
crd = numpy_to_ranked_memref(coords_np)
data = numpy_to_ranked_memref(data_np)

storage = coo_format._get_ctypes_type()(pos, crd, data)
_hold_ref(storage, pos_np)
_hold_ref(storage, coords_np)
_hold_ref(storage, data_np)
return Array(storage=storage, shape=arr.shape)
return from_constituent_arrays(format=coo_format, arrays=(pos, coords, data), shape=arr.shape)
case _:
raise NotImplementedError(f"No conversion implemented for `scipy.sparse.{type(arr.__name__)}`.")

Expand Down

0 comments on commit a81c8d4

Please sign in to comment.