diff --git a/sparse/mlir_backend/_constructors.py b/sparse/mlir_backend/_constructors.py index b382c4f2..1f301908 100644 --- a/sparse/mlir_backend/_constructors.py +++ b/sparse/mlir_backend/_constructors.py @@ -49,6 +49,55 @@ def free_memref(obj: ctypes.Structure) -> None: ########### +@fn_cache +def get_sparse_vector_class( + values_dtype: type[DType], + index_dtype: type[DType], +) -> type[ctypes.Structure]: + class SparseVector(ctypes.Structure): + _fields_ = [ + ("indptr", get_nd_memref_descr(1, index_dtype)), + ("indices", get_nd_memref_descr(1, index_dtype)), + ("data", get_nd_memref_descr(1, values_dtype)), + ] + dtype = values_dtype + _index_dtype = index_dtype + + @classmethod + def from_sps(cls, arrs: list[np.ndarray]) -> "SparseVector": + sv_instance = cls(*[numpy_to_ranked_memref(arr) for arr in arrs]) + for arr in arrs: + _take_owneship(sv_instance, arr) + return sv_instance + + def to_sps(self, shape: tuple[int, ...]) -> int: + return PackedArgumentTuple(tuple(ranked_memref_to_numpy(field) for field in self.get__fields_())) + + def to_module_arg(self) -> list: + return [ + ctypes.pointer(ctypes.pointer(self.indptr)), + ctypes.pointer(ctypes.pointer(self.indices)), + ctypes.pointer(ctypes.pointer(self.data)), + ] + + def get__fields_(self) -> list: + return [self.indptr, self.indices, self.data] + + @classmethod + @fn_cache + def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType: + with ir.Location.unknown(ctx): + values_dtype = cls.dtype.get_mlir_type() + index_dtype = cls._index_dtype.get_mlir_type() + index_width = getattr(index_dtype, "width", 0) + levels = (sparse_tensor.LevelFormat.compressed,) + ordering = ir.AffineMap.get_permutation([0]) + encoding = sparse_tensor.EncodingAttr.get(levels, ordering, ordering, index_width, index_width) + return ir.RankedTensorType.get(list(shape), values_dtype, encoding) + + return SparseVector + + @fn_cache def get_csx_class( values_dtype: type[DType], @@ -302,6 +351,16 @@ def get_csx_scipy_class(order: str) -> type[sps.sparray]: raise Exception(f"Invalid order: {order}") +_constructor_class_dict = { + "csr": get_csx_class, + "csc": get_csx_class, + "csf": get_csf_class, + "coo": get_coo_class, + "sparse_vector": get_sparse_vector_class, + "dense": get_dense_class, +} + + ################ # Tensor class # ################ @@ -346,8 +405,8 @@ def __init__( self._obj = obj elif format is not None: - if format in ["csf", "coo"]: - fn_format_class = get_csf_class if format == "csf" else get_coo_class + if format in ["csf", "coo", "sparse_vector"]: + fn_format_class = _constructor_class_dict[format] self._owns_memory = False self._index_dtype = asdtype(np.intp) self._format_class = fn_format_class(self._values_dtype, self._index_dtype) diff --git a/sparse/mlir_backend/tests/test_simple.py b/sparse/mlir_backend/tests/test_simple.py index 98ac90f2..0fb2e4d2 100644 --- a/sparse/mlir_backend/tests/test_simple.py +++ b/sparse/mlir_backend/tests/test_simple.py @@ -94,7 +94,7 @@ def test_dense_format(dtype, shape): @parametrize_dtypes -def test_constructors(rng, dtype): +def test_2d_constructors(rng, dtype): SHAPE = (80, 100) DENSITY = 0.6 sampler = generate_sampler(dtype, rng) @@ -219,6 +219,35 @@ def test_coo_3d_format(dtype): # np.testing.assert_array_equal(actual, expected) +@parametrize_dtypes +def test_sparse_vector_format(dtype): + SHAPE = (10,) + pos = np.array([0, 6]) + crd = np.array([0, 1, 2, 6, 8, 9]) + data = np.array([1, 2, 3, 4, 5, 6], dtype=dtype) + sparse_vector = [pos, crd, data] + + sv_tensor = sparse.asarray( + sparse_vector, + shape=SHAPE, + dtype=sparse.asdtype(dtype), + format="sparse_vector", + ) + result = sv_tensor.to_scipy_sparse() + for actual, expected in zip(result, sparse_vector, strict=False): + np.testing.assert_array_equal(actual, expected) + + res_tensor = sparse.add(sv_tensor, sv_tensor).to_scipy_sparse() + sparse_vector_2 = [pos, crd, data * 2] + for actual, expected in zip(res_tensor, sparse_vector_2, strict=False): + np.testing.assert_array_equal(actual, expected) + + dense = np.array([1, 2, 3, 0, 0, 0, 4, 0, 5, 6], dtype=dtype) + dense_tensor = sparse.asarray(dense) + res_tensor = sparse.add(dense_tensor, sv_tensor).to_scipy_sparse() + np.testing.assert_array_equal(res_tensor, dense * 2) + + @parametrize_dtypes def test_reshape(rng, dtype): DENSITY = 0.5