diff --git a/pixi.toml b/pixi.toml
index 01360600..b25a1269 100644
--- a/pixi.toml
+++ b/pixi.toml
@@ -27,7 +27,7 @@ mkdocs-jupyter = "*"
 
 [feature.tests.tasks]
 test = "pytest --pyargs sparse -n auto"
-test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -v" }
+test-mlir = { cmd = "pytest --pyargs sparse.mlir_backend -v" }
 test-finch = { cmd = "pytest --pyargs sparse/tests -n auto -v", depends-on = ["precompile"] }
 
 [feature.tests.dependencies]
@@ -55,11 +55,14 @@ finch-tensor = ">=0.1.31"
 SPARSE_BACKEND = "Finch"
 
 [feature.finch.target.osx-arm64.activation.env]
+SPARSE_BACKEND = "Finch"
 PYTHONFAULTHANDLER = "${HOME}/faulthandler.log"
 
 [feature.mlir.dependencies]
 scipy = ">=0.19"
-mlir-python-bindings = "19.*"
+
+[feature.mlir.target.osx-arm64.pypi-dependencies]
+finch-mlir = ">=0.0.2"
 
 [feature.mlir.activation.env]
 SPARSE_BACKEND = "MLIR"
@@ -67,5 +70,5 @@ SPARSE_BACKEND = "MLIR"
 
 [environments]
 tests = ["tests", "extras"]
 docs = ["docs", "extras"]
-mlir-dev = ["tests", "mlir"]
-finch-dev = ["tests", "finch"]
+mlir-dev = {features = ["tests", "mlir"], no-default-feature = true}
+finch-dev = {features = ["tests", "finch"], no-default-feature = true}
diff --git a/sparse/mlir_backend/__init__.py b/sparse/mlir_backend/__init__.py
index 20a02beb..f60410ac 100644
--- a/sparse/mlir_backend/__init__.py
+++ b/sparse/mlir_backend/__init__.py
@@ -27,7 +27,7 @@
     uint32,
     uint64,
 )
-from ._ops import add
+from ._ops import add, reshape
 
 __all__ = [
     "add",
@@ -36,6 +36,7 @@
     "to_numpy",
     "to_scipy",
     "levels",
+    "reshape",
     "from_constituent_arrays",
     "int8",
     "int16",
diff --git a/sparse/mlir_backend/_array.py b/sparse/mlir_backend/_array.py
index 50b863b0..ed88efc1 100644
--- a/sparse/mlir_backend/_array.py
+++ b/sparse/mlir_backend/_array.py
@@ -41,5 +41,10 @@ def copy(self) -> "Array":
         arrs = tuple(arr.copy() for arr in self.get_constituent_arrays())
         return from_constituent_arrays(format=self.format, arrays=arrs, shape=self.shape)
 
+    def asformat(self, format: StorageFormat) -> "Array":
+        from ._ops import asformat
+
+        return asformat(self, format=format)
+
     def get_constituent_arrays(self) -> tuple[np.ndarray, ...]:
         return self._storage.get_constituent_arrays()
diff --git a/sparse/mlir_backend/_common.py b/sparse/mlir_backend/_common.py
index b382a822..66867e60 100644
--- a/sparse/mlir_backend/_common.py
+++ b/sparse/mlir_backend/_common.py
@@ -1,6 +1,7 @@
 import ctypes
 import functools
 import weakref
+from collections.abc import Iterable
 
 import mlir_finch.runtime as rt
 
@@ -52,3 +53,13 @@ def finalizer(ptr):
         ctypes.pythonapi.Py_DecRef(ptr)
 
     weakref.finalize(owner, finalizer, ptr)
+
+
+def as_shape(x) -> tuple[int, ...]:
+    if not isinstance(x, Iterable):
+        x = (x,)
+
+    if not all(isinstance(xi, int) for xi in x):
+        raise TypeError("Shape must be an `int` or tuple of `int`s.")
+
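+    # e.g. as_shape(3) -> (3,); as_shape((2, 3)) -> (2, 3)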
+    return tuple(int(xi) for xi in x)
diff --git a/sparse/mlir_backend/_core.py b/sparse/mlir_backend/_core.py
index 9488ea46..ac7065cb 100644
--- a/sparse/mlir_backend/_core.py
+++ b/sparse/mlir_backend/_core.py
@@ -28,6 +28,12 @@
 libc.free.argtypes = [ctypes.c_void_p]
 libc.free.restype = None
 
+SHARED_LIBS = []
+if DEBUG:
+    SHARED_LIBS.append(MLIR_C_RUNNER_UTILS)
+
+OPT_LEVEL = 0 if DEBUG else 2
+
 # TODO: remove global state
 ctx = Context()
 
diff --git a/sparse/mlir_backend/_dtypes.py b/sparse/mlir_backend/_dtypes.py
index 31d8c5f8..7dad1438 100644
--- a/sparse/mlir_backend/_dtypes.py
+++ b/sparse/mlir_backend/_dtypes.py
@@ -76,10 +76,10 @@ def np_dtype(self) -> np.dtype:
         return np.dtype(getattr(np, f"uint{self.bit_width}"))
 
 
-int8 = UnsignedIntegerDType(bit_width=8)
-int16 = UnsignedIntegerDType(bit_width=16)
-int32 = UnsignedIntegerDType(bit_width=32)
-int64 = UnsignedIntegerDType(bit_width=64)
+uint8 = UnsignedIntegerDType(bit_width=8)
+uint16 = UnsignedIntegerDType(bit_width=16)
+uint32 = UnsignedIntegerDType(bit_width=32)
+uint64 = UnsignedIntegerDType(bit_width=64)
 
 
 @dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
@@ -89,10 +89,10 @@ def np_dtype(self) -> np.dtype:
         return np.dtype(getattr(np, f"int{self.bit_width}"))
 
 
-uint8 = SignedIntegerDType(bit_width=8)
-uint16 = SignedIntegerDType(bit_width=16)
-uint32 = SignedIntegerDType(bit_width=32)
-uint64 = SignedIntegerDType(bit_width=64)
+int8 = SignedIntegerDType(bit_width=8)
+int16 = SignedIntegerDType(bit_width=16)
+int32 = SignedIntegerDType(bit_width=32)
+int64 = SignedIntegerDType(bit_width=64)
 
 
 intp: SignedIntegerDType = locals()[f"int{_PTR_WIDTH}"]
diff --git a/sparse/mlir_backend/_ops.py b/sparse/mlir_backend/_ops.py
index 20eee897..029df872 100644
--- a/sparse/mlir_backend/_ops.py
+++ b/sparse/mlir_backend/_ops.py
@@ -1,14 +1,18 @@
 import ctypes
+import math
 
 import mlir_finch.execution_engine
 import mlir_finch.passmanager
 from mlir_finch import ir
 from mlir_finch.dialects import arith, complex, func, linalg, sparse_tensor, tensor
 
+import numpy as np
+
 from ._array import Array
-from ._common import fn_cache
-from ._core import CWD, DEBUG, SHARED_LIBS, ctx, pm
+from ._common import as_shape, fn_cache
+from ._core import CWD, DEBUG, OPT_LEVEL, SHARED_LIBS, ctx, pm
 from ._dtypes import DType, IeeeComplexFloatingDType, IeeeRealFloatingDType, IntegerDType
+from .levels import StorageFormat, _determine_format
 
 
 @fn_cache
@@ -17,7 +21,6 @@ def get_add_module(
     b_tensor_type: ir.RankedTensorType,
     out_tensor_type: ir.RankedTensorType,
     dtype: DType,
-    rank: int,
 ) -> ir.Module:
     with ir.Location.unknown(ctx):
         module = ir.Module.create()
@@ -31,7 +34,7 @@
             raise RuntimeError(f"Can not add {dtype=}.")
 
         dtype = dtype._get_mlir_type()
-        ordering = ir.AffineMap.get_permutation(range(rank))
+        max_rank = out_tensor_type.rank
 
         with ir.InsertionPoint(module.body):
 
@@ -42,8 +45,13 @@ def add(a, b):
                 [out_tensor_type],
                 [a, b],
                 [out],
-                ir.ArrayAttr.get([ir.AffineMapAttr.get(p) for p in (ordering,) * 3]),
-                ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * rank),
+                ir.ArrayAttr.get(
+                    [
+                        ir.AffineMapAttr.get(ir.AffineMap.get_minor_identity(max_rank, t.rank))
+                        for t in (a_tensor_type, b_tensor_type, out_tensor_type)
+                    ]
+                ),
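+                # One minor-identity map per operand: each operand indexes the trailing dims of the output, so lower-rank inputs broadcast.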
+                ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * max_rank),
             )
             block = generic_op.regions[0].blocks.append(dtype, dtype, dtype)
             with ir.InsertionPoint(block):
@@ -72,7 +80,7 @@ def add(a, b):
     if DEBUG:
         (CWD / "add_module_opt.mlir").write_text(str(module))
 
-    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
 
 
 @fn_cache
@@ -97,7 +105,7 @@ def reshape(a, shape):
     if DEBUG:
         (CWD / "reshape_module_opt.mlir").write_text(str(module))
 
-    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
 
 
 @fn_cache
@@ -125,21 +133,44 @@ def broadcast_to(in_tensor):
     if DEBUG:
         (CWD / "broadcast_to_module_opt.mlir").write_text(str(module))
 
-    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
+
+
+@fn_cache
+def get_convert_module(
+    in_tensor_type: ir.RankedTensorType,
+    out_tensor_type: ir.RankedTensorType,
+):
+    with ir.Location.unknown(ctx):
+        module = ir.Module.create()
+
+        with ir.InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(in_tensor_type)
+            def convert(in_tensor):
+                return sparse_tensor.convert(out_tensor_type, in_tensor)
 
-def add(x1: Array, x2: Array) -> Array:
-    ret_storage_format = x1.format
+            convert.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    if DEBUG:
+        (CWD / "convert_module.mlir").write_text(str(module))
+    pm.run(module.operation)
+    if DEBUG:
+        (CWD / "convert_module_opt.mlir").write_text(str(module))
+
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
+
+
+def add(x1: Array, x2: Array, /) -> Array:
+    # TODO: Determine output format via autoscheduler
+    ret_storage_format = _determine_format(x1.format, x2.format, dtype=x1.dtype, union=True)
     ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
-    out_tensor_type = ret_storage_format._get_mlir_type(shape=x1.shape)
+    out_tensor_type = ret_storage_format._get_mlir_type(shape=np.broadcast_shapes(x1.shape, x2.shape))
 
-    # TODO: Decide what will be the output tensor_type
     add_module = get_add_module(
         x1._get_mlir_type(),
         x2._get_mlir_type(),
         out_tensor_type=out_tensor_type,
         dtype=x1.dtype,
-        rank=x1.ndim,
     )
     add_module.invoke(
         "add",
@@ -147,4 +178,49 @@ def add(x1: Array, x2: Array) -> Array:
         *x1._to_module_arg(),
         *x2._to_module_arg(),
     )
-    return Array(storage=ret_storage, shape=out_tensor_type.shape)
+    return Array(storage=ret_storage, shape=tuple(out_tensor_type.shape))
+
+
+def asformat(x: Array, /, format: StorageFormat) -> Array:
+    if x.format == format:
+        return x
+
+    out_tensor_type = format._get_mlir_type(shape=x.shape)
+    ret_storage = format._get_ctypes_type(owns_memory=True)()
+
+    convert_module = get_convert_module(
+        x._get_mlir_type(),
+        out_tensor_type,
+    )
+
+    convert_module.invoke(
+        "convert",
+        ctypes.pointer(ctypes.pointer(ret_storage)),
+        *x._to_module_arg(),
+    )
+
+    return Array(storage=ret_storage, shape=x.shape)
+
+
+def reshape(x: Array, /, shape: tuple[int, ...]) -> Array:
+    from ._conversions import _from_numpy
+
+    shape = as_shape(shape)
+    if math.prod(x.shape) != math.prod(shape):
+        raise ValueError(f"`math.prod(x.shape) != math.prod(shape)`, {x.shape=}, {shape=}")
+
+    ret_storage_format = _determine_format(x.format, dtype=x.dtype, union=len(shape) > x.ndim, out_ndim=len(shape))
+    shape_array = _from_numpy(np.asarray(shape, dtype=np.uint64))
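+    # The compiled reshape kernel takes the target shape as a 1-D tensor operand.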
+    out_tensor_type = ret_storage_format._get_mlir_type(shape=shape)
+    ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
+
+    reshape_module = get_reshape_module(x._get_mlir_type(), shape_array._get_mlir_type(), out_tensor_type)
+
+    reshape_module.invoke(
+        "reshape",
+        ctypes.pointer(ctypes.pointer(ret_storage)),
+        *x._to_module_arg(),
+        *shape_array._to_module_arg(),
+    )
+
+    return Array(storage=ret_storage, shape=shape)
diff --git a/sparse/mlir_backend/levels.py b/sparse/mlir_backend/levels.py
index 6559ec4c..c7021523 100644
--- a/sparse/mlir_backend/levels.py
+++ b/sparse/mlir_backend/levels.py
@@ -209,3 +209,100 @@ def _get_storage_format(
         crd_width=crd_width,
         dtype=dtype,
     )
+
+
+def _is_sparse_level(lvl: Level | LevelFormat, /) -> bool:
+    assert isinstance(lvl, Level | LevelFormat)
+    if isinstance(lvl, Level):
+        lvl = lvl.format
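+    # Anything other than a Dense level (e.g. Compressed) counts as sparse.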
+    return lvl != LevelFormat.Dense
+
+
+def _count_sparse_levels(format: StorageFormat) -> int:
+    return sum(_is_sparse_level(lvl) for lvl in format.levels)
+
+
+def _count_dense_levels(format: StorageFormat) -> int:
+    return sum(not _is_sparse_level(lvl) for lvl in format.levels)
+
+
+def _get_sparse_dense_levels(
+    *, n_sparse: int | None = None, n_dense: int | None = None, ndim: int | None = None
+) -> tuple[Level, ...]:
+    if (n_sparse is not None) + (n_dense is not None) + (ndim is not None) != 2:
+        assert n_sparse is not None and n_dense is not None and ndim is not None  # all three given; must agree
+        assert n_sparse + n_dense == ndim
+    if n_sparse is None:
+        n_sparse = ndim - n_dense
+    if n_dense is None:
+        n_dense = ndim - n_sparse
+    if ndim is None:
+        ndim = n_dense + n_sparse
+
+    assert ndim >= 0
+    assert n_dense >= 0
+    assert n_sparse >= 0
+
+    return (Level(LevelFormat.Dense),) * n_dense + (Level(LevelFormat.Compressed),) * n_sparse
+
+
+def _determine_format(*formats: StorageFormat, dtype: DType, union: bool, out_ndim: int | None = None) -> StorageFormat:
+    """Determines the output format from a group of input formats.
+
+    1. Counts the dense levels of each input for `union=True`, and the sparse ones for `union=False`.
+    2. Takes the maximum of those counts across the input formats.
+    3. Constructs a format of rank `out_ndim` (the max rank of the inputs if it's `None`).
+       For `union=False` the count gives the number of sparse levels; for `union=True` it gives the dense ones.
+       Sparse levels are emitted as `LevelFormat.Compressed`.
+
+    Returns
+    -------
+    StorageFormat
+        Output storage format.
+    """
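+    # Worked example: two CSR inputs have levels (Dense, Compressed). With union=True the max
+    # dense count is 1, so the result keeps 1 dense level and 1 compressed level (CSR again).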
+    if len(formats) == 0:
+        if out_ndim is None:
+            out_ndim = 0
+        return get_storage_format(
+            levels=(Level(LevelFormat.Dense if union else LevelFormat.Compressed),) * out_ndim,
+            order="C",
+            pos_width=64,
+            crd_width=64,
+            dtype=dtype,
+        )
+
+    if out_ndim is None:
+        out_ndim = max(fmt.rank for fmt in formats)
+
+    pos_width = 0
+    crd_width = 0
+    counter = _count_sparse_levels if not union else _count_dense_levels
+    n_counted = None
+    order = ()
+    for fmt in formats:
+        n_counted = counter(fmt) if n_counted is None else max(n_counted, counter(fmt))
+        pos_width = max(pos_width, fmt.pos_width)
+        crd_width = max(crd_width, fmt.crd_width)
+        if order != "C":
+            if fmt.order[: len(order)] == order:
+                order = fmt.order
+            elif order[: len(fmt.order)] != fmt.order:
+                order = "C"
+
+    if not isinstance(order, str):
+        order = order + tuple(range(len(order), out_ndim))
+        order = order[:out_ndim]
+
+    if out_ndim < n_counted:
+        n_counted = out_ndim
+
+    n_sparse = n_counted if not union else out_ndim - n_counted
+
+    levels = _get_sparse_dense_levels(n_sparse=n_sparse, ndim=out_ndim)
+    return get_storage_format(
+        levels=levels,
+        order=order,
+        pos_width=pos_width,
+        crd_width=crd_width,
+        dtype=dtype,
+    )
diff --git a/sparse/mlir_backend/tests/test_simple.py b/sparse/mlir_backend/tests/test_simple.py
index f8ae1f31..2f4fe12c 100644
--- a/sparse/mlir_backend/tests/test_simple.py
+++ b/sparse/mlir_backend/tests/test_simple.py
@@ -31,9 +31,9 @@
 )
 
 
-def assert_csx_equal(
-    expected: sps.csr_array | sps.csc_array,
-    actual: sps.csr_array | sps.csc_array,
+def assert_sps_equal(
+    expected: sps.csr_array | sps.csc_array | sps.coo_array,
+    actual: sps.csr_array | sps.csc_array | sps.coo_array,
 ) -> None:
     assert expected.format == actual.format
     expected.eliminate_zeros()
@@ -42,8 +42,13 @@
     actual.eliminate_zeros()
     actual.sum_duplicates()
 
-    np.testing.assert_array_equal(expected.indptr, actual.indptr)
-    np.testing.assert_array_equal(expected.indices, actual.indices)
+    if expected.format != "coo":
+        np.testing.assert_array_equal(expected.indptr, actual.indptr)
+        np.testing.assert_array_equal(expected.indices, actual.indices)
+    else:
+        np.testing.assert_array_equal(expected.row, actual.row)
+        np.testing.assert_array_equal(expected.col, actual.col)
+
     np.testing.assert_array_equal(expected.data, actual.data)
@@ -85,7 +90,7 @@ def sampler_complex_floating(size: tuple[int, ...]):
     raise NotImplementedError(f"{dtype=} not yet supported.")
 
 
-def get_exampe_csf_arrays(dtype: np.dtype) -> tuple:
+def get_example_csf_arrays(dtype: np.dtype) -> tuple:
     pos_1 = np.array([0, 1, 3], dtype=np.int64)
     crd_1 = np.array([1, 0, 1], dtype=np.int64)
     pos_2 = np.array([0, 3, 5, 7], dtype=np.int64)
@@ -121,10 +126,10 @@ def test_2d_constructors(rng, dtype):
     dense_2_tensor = sparse.asarray(np.arange(100, dtype=dtype).reshape((25, 4)) + 10)
 
-    csr_retured = sparse.to_scipy(csr_tensor)
-    assert_csx_equal(csr_retured, csr)
+    csr_returned = sparse.to_scipy(csr_tensor)
+    assert_sps_equal(csr_returned, csr)
 
-    csc_retured = sparse.to_scipy(csc_tensor)
-    assert_csx_equal(csc_retured, csc)
+    csc_returned = sparse.to_scipy(csc_tensor)
+    assert_sps_equal(csc_returned, csc)
 
     dense_returned = sparse.to_numpy(dense_tensor)
     np.testing.assert_equal(dense_returned, dense)
@@ -157,19 +162,19 @@ def test_add(rng, dtype):
 
     actual = sparse.to_scipy(sparse.add(csr_tensor, csr_2_tensor))
     expected = csr + csr_2
-    assert_csx_equal(expected, actual)
+    assert_sps_equal(expected, actual)
 
     actual = sparse.to_scipy(sparse.add(csc_tensor, csc_tensor))
     expected = csc + csc
-    assert_csx_equal(expected, actual)
+    assert_sps_equal(expected, actual)
 
     actual = sparse.to_scipy(sparse.add(csc_tensor, csr_tensor))
-    expected = csc + csr
-    assert_csx_equal(expected, actual)
+    expected = (csc + csr).asformat("csr")
+    assert_sps_equal(expected, actual)
 
-    actual = sparse.to_scipy(sparse.add(csr_tensor, dense_tensor))
-    expected = sps.csr_matrix(csr + dense)
-    assert_csx_equal(expected, actual)
+    actual = sparse.to_numpy(sparse.add(csr_tensor, dense_tensor))
+    expected = csr + dense
+    np.testing.assert_array_equal(actual, expected)
 
     actual = sparse.to_numpy(sparse.add(dense_tensor, csr_tensor))
     expected = csr + dense
@@ -183,9 +188,11 @@ def test_add(rng, dtype):
 
     actual = sparse.to_scipy(sparse.add(csr_2_tensor, coo_tensor))
     expected = csr_2 + coo
-    assert_csx_equal(expected, actual)
+    assert_sps_equal(expected, actual)
 
-    actual = sparse.to_scipy(sparse.add(coo_tensor, coo_tensor))
+    # This ends up being DCSR, not COO
+    actual_tensor = sparse.add(coo_tensor, coo_tensor)
+    actual = sparse.to_scipy(actual_tensor.asformat(coo_tensor.format))
     expected = coo + coo
     np.testing.assert_array_equal(actual.todense(), expected.todense())
@@ -205,7 +212,7 @@ def test_csf_format(dtype):
     )
 
     SHAPE = (2, 2, 4)
-    pos_1, crd_1, pos_2, crd_2, data = get_exampe_csf_arrays(dtype)
+    pos_1, crd_1, pos_2, crd_2, data = get_example_csf_arrays(dtype)
     constituent_arrays = (pos_1, crd_1, pos_2, crd_2, data)
 
     csf_array = sparse.from_constituent_arrays(format=format, arrays=constituent_arrays, shape=SHAPE)
@@ -247,7 +254,7 @@ def test_coo_3d_format(dtype):
     for actual, expected in zip(result, carrs, strict=True):
         np.testing.assert_array_equal(actual, expected)
 
-    result_arrays = sparse.add(coo_array, coo_array).get_constituent_arrays()
+    result_arrays = sparse.add(coo_array, coo_array).asformat(coo_array.format).get_constituent_arrays()
     constituent_arrays = (pos, *crd, data * 2)
     for actual, expected in zip(result_arrays, constituent_arrays, strict=False):
         np.testing.assert_array_equal(actual, expected)
@@ -297,3 +304,173 @@ def test_copy():
     np.testing.assert_array_equal(sparse.to_numpy(arr_sp1), arr_np_orig)
     np.testing.assert_array_equal(sparse.to_numpy(arr_sp2), arr_np_orig)
     np.testing.assert_array_equal(sparse.to_numpy(arr_sp3), arr_np_copy)
+
+
+@parametrize_dtypes
+@pytest.mark.parametrize(
+    "format",
+    [
+        "csr",
+        pytest.param("csc", marks=pytest.mark.xfail(reason="https://github.com/llvm/llvm-project/pull/109641")),
+        "coo",
+    ],
+)
+@pytest.mark.parametrize(
+    ("shape", "new_shape"),
+    [
+        ((100, 50), (25, 200)),
+        ((100, 50), (10, 500, 1)),
+        ((80, 1), (8, 10)),
+        ((80, 1), (80,)),
+    ],
+)
+def test_reshape(rng, dtype, format, shape, new_shape):
+    DENSITY = 0.5
+    sampler = generate_sampler(dtype, rng)
+
+    arr_sps = sps.random_array(
+        shape, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler
+    )
+    arr_sps.eliminate_zeros()
+    arr_sps.sum_duplicates()
+    arr = sparse.asarray(arr_sps)
+
+    actual = sparse.reshape(arr, shape=new_shape)
+    assert actual.shape == new_shape
+
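+    # `to_scipy` raises RuntimeError for formats SciPy cannot represent (e.g. 1-D or 3-D results);
+    # in that case fall back to comparing through dense conversions.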
+    try:
+        scipy_format = sparse.to_scipy(actual).format
+    except RuntimeError:
+        tmp_levels = (sparse.levels.Level(sparse.levels.LevelFormat.Dense),) * len(shape)
+        tmp_fmt = sparse.levels.get_storage_format(
+            levels=tmp_levels,
+            order="C",
+            pos_width=64,
+            crd_width=64,
+            dtype=dtype,
+        )
+        arr_dense = arr.asformat(tmp_fmt)
+        arr_np = sparse.to_numpy(arr_dense)
+        expected_np = arr_np.reshape(new_shape)
+
+        out_levels = (sparse.levels.Level(sparse.levels.LevelFormat.Dense),) * len(new_shape)
+        out_fmt = sparse.levels.get_storage_format(
+            levels=out_levels,
+            order="C",
+            pos_width=64,
+            crd_width=64,
+            dtype=dtype,
+        )
+        actual_dense = actual.asformat(out_fmt)
+        actual_np = sparse.to_numpy(actual_dense)
+
+        np.testing.assert_array_equal(expected_np, actual_np)
+        return
+
+    expected = sparse.asarray(arr_sps.reshape(new_shape).asformat(scipy_format))
+
+    for x, y in zip(expected.get_constituent_arrays(), actual.get_constituent_arrays(), strict=True):
+        np.testing.assert_array_equal(x, y)
+
+
+@parametrize_dtypes
+def test_reshape_csf(dtype):
+    # CSF
+    csf_shape = (2, 2, 4)
+    csf_format = sparse.levels.get_storage_format(
+        levels=(
+            sparse.levels.Level(sparse.levels.LevelFormat.Dense),
+            sparse.levels.Level(sparse.levels.LevelFormat.Compressed),
+            sparse.levels.Level(sparse.levels.LevelFormat.Compressed),
+        ),
+        order="C",
+        pos_width=64,
+        crd_width=64,
+        dtype=sparse.asdtype(dtype),
+    )
+    for shape, new_shape, expected_arrs in [
+        (
+            csf_shape,
+            (4, 4, 1),
+            [
+                np.array([0, 0, 3, 5, 7]),
+                np.array([0, 1, 3, 0, 3, 0, 1]),
+                np.array([0, 1, 2, 3, 4, 5, 6, 7]),
+                np.array([0, 0, 0, 0, 0, 0, 0]),
+                np.array([1, 2, 3, 4, 5, 6, 7]),
+            ],
+        ),
+        (
+            csf_shape,
+            (2, 1, 8),
+            [
+                np.array([0, 1, 2]),
+                np.array([0, 0]),
+                np.array([0, 3, 7]),
+                np.array([4, 5, 7, 0, 3, 4, 5]),
+                np.array([1, 2, 3, 4, 5, 6, 7]),
+            ],
+        ),
+    ]:
+        arrs = get_example_csf_arrays(dtype)
+        csf_tensor = sparse.from_constituent_arrays(format=csf_format, arrays=arrs, shape=shape)
+
+        result = sparse.reshape(csf_tensor, shape=new_shape)
+        for actual, expected in zip(result.get_constituent_arrays(), expected_arrs, strict=True):
+            np.testing.assert_array_equal(actual, expected)
+
+
+@parametrize_dtypes
+def test_reshape_dense(dtype):
+    SHAPE = (2, 2, 4)
+
+    np_arr = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)
+    sp_arr = sparse.asarray(np_arr)
+
+    for new_shape in [
+        (4, 4, 1),
+        (2, 1, 8),
+    ]:
+        expected = np_arr.reshape(new_shape)
+        actual = sparse.reshape(sp_arr, new_shape)
+
+        actual_np = sparse.to_numpy(actual)
+
+        assert actual_np.dtype == expected.dtype
+        np.testing.assert_equal(actual_np, expected)
+
+
+@pytest.mark.parametrize(
+    "src_fmt",
+    [
+        "csr",
+        "csc",
+        pytest.param("coo", marks=pytest.mark.skip(reason="https://github.com/llvm/llvm-project/issues/116012")),
+    ],
+)
+@pytest.mark.parametrize(
+    "dst_fmt",
+    [
+        "csr",
+        "csc",
+        pytest.param("coo", marks=pytest.mark.skip(reason="https://github.com/llvm/llvm-project/issues/116012")),
+    ],
+)
+def test_asformat(rng, src_fmt, dst_fmt):
+    SHAPE = (100, 50)
+    DENSITY = 0.5
+    sampler = generate_sampler(np.float64, rng)
+
+    sps_arr = sps.random_array(
+        SHAPE, density=DENSITY, format=src_fmt, dtype=np.float64, random_state=rng, data_sampler=sampler
+    )
+    sp_arr = sparse.asarray(sps_arr)
+
+    expected = sps_arr.asformat(dst_fmt)
+
+    actual_fmt = sparse.asarray(expected, copy=False).format
+    actual = sp_arr.asformat(actual_fmt)
+    actual_sps = sparse.to_scipy(actual)
+
+    assert actual_sps.format == dst_fmt
+    assert_sps_equal(expected, actual_sps)