diff --git a/pixi.toml b/pixi.toml
index 52e57fb3..5e68051e 100644
--- a/pixi.toml
+++ b/pixi.toml
@@ -28,7 +28,7 @@ mkdocs-jupyter = "*"
 [feature.tests.tasks]
 test = "pytest --pyargs sparse -n auto"
 test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -vvv", env = { SPARSE_BACKEND = "MLIR" } }
-test-finch = { cmd = "pytest --pyargs sparse/tests -n auto", env = { SPARSE_BACKEND = "Finch" }, depends-on = ["precompile"] }
+test-finch = { cmd = "pytest --pyargs sparse/tests -n auto -vvv", env = { SPARSE_BACKEND = "Finch", PYTHONFAULTHANDLER = "1" }, depends-on = ["precompile"] }
 
 [feature.tests.dependencies]
 pytest = ">=3.5"
diff --git a/sparse/mlir_backend/_levels.py b/sparse/mlir_backend/_levels.py
new file mode 100644
index 00000000..8f110dc1
--- /dev/null
+++ b/sparse/mlir_backend/_levels.py
@@ -0,0 +1,190 @@
+import ctypes
+import dataclasses
+import enum
+import itertools
+import re
+import typing
+
+import mlir.runtime as rt
+from mlir import ir
+from mlir.dialects import sparse_tensor
+
+import numpy as np
+
+from ._common import MlirType, fn_cache
+from ._dtypes import DType, asdtype
+
+# Applied in order: split before a capitalised word, then between a lowercase/digit and an uppercase letter.
+_CAMEL_TO_SNAKE = [re.compile("(.)([A-Z][a-z]+)"), re.compile("([a-z0-9])([A-Z])")]
+
+
+def _camel_to_snake(name: str) -> str:
+    """Convert a CamelCase identifier to snake_case (e.g. ``NonOrdered`` -> ``non_ordered``)."""
+    for exp in _CAMEL_TO_SNAKE:
+        name = exp.sub(r"\1_\2", name)
+
+    return name.lower()
+
+
+@fn_cache
+def get_nd_memref_descr(rank: int, dtype: type[DType]) -> type:
+    """Return the memref descriptor type for ``rank`` dimensions of ``dtype``'s ctypes type."""
+    return rt.make_nd_memref_descriptor(rank, dtype.to_ctype())
+
+
+class LevelProperties(enum.Flag):
+    """Combinable per-level property flags."""
+
+    NonOrdered = enum.auto()
+    NonUnique = enum.auto()
+
+    def build(self) -> list[sparse_tensor.LevelProperty]:
+        """Return the ``sparse_tensor.LevelProperty`` attribute for every flag set in ``self``."""
+        return [getattr(sparse_tensor.LevelProperty, _camel_to_snake(p.name)) for p in type(self) if p in self]
+
+
+class LevelFormat(enum.Enum):
+    """Storage format of one level; values name attributes of ``sparse_tensor.LevelFormat``."""
+
+    Dense = "dense"
+    Compressed = "compressed"
+    Singleton = "singleton"
+
+    def build(self) -> sparse_tensor.LevelFormat:
+        """Return the ``sparse_tensor.LevelFormat`` attribute named by this member's value."""
+        return getattr(sparse_tensor.LevelFormat, self.value)
+
+
+@dataclasses.dataclass
+class Level:
+    """A single storage level: a format plus optional property flags."""
+
+    format: LevelFormat
+    properties: LevelProperties = LevelProperties(0)
+
+    def build(self):
+        """Return the level type built by ``EncodingAttr.build_level_type`` for this level."""
+        # The result must be returned: ``StorageFormat.get_mlir_type`` collects it for every level.
+        return sparse_tensor.EncodingAttr.build_level_type(self.format.build(), self.properties.build())
+
+
+@dataclasses.dataclass
+class StorageFormat:
+    """Description of a sparse storage layout: levels, dimension order, index widths and dtype.
+
+    Instances hash and compare by identity (see ``__hash__`` and ``__eq__``).
+    """
+
+    levels: tuple[Level, ...]
+    order: typing.Literal["C", "F"] | tuple[int, ...]
+    pos_width: int
+    crd_width: int
+    dtype: type[DType]
+
+    @property
+    def storage_rank(self) -> int:
+        """Number of storage levels."""
+        return len(self.levels)
+
+    @property
+    def rank(self) -> int:
+        """Tensor rank; currently identical to ``storage_rank``."""
+        return self.storage_rank
+
+    def __post_init__(self):
+        """Normalise ``dtype`` via ``asdtype`` and ``order`` to an explicit permutation tuple.
+
+        Raises
+        ------
+        ValueError
+            If ``order`` is neither ``"C"``, ``"F"`` nor a permutation of ``range(rank)``.
+        """
+        rank = self.storage_rank
+        self.dtype = asdtype(self.dtype)
+        if self.order == "C":
+            self.order = tuple(range(rank))
+            return
+
+        if self.order == "F":
+            self.order = tuple(range(rank))[::-1]
+            return
+
+        if sorted(self.order) != list(range(rank)):
+            raise ValueError(f"`sorted(self.order) != list(range(rank))`, {self.order=}, {rank=}.")
+
+        self.order = tuple(self.order)
+
+    @fn_cache
+    def get_mlir_type(self, *, shape: tuple[int, ...]) -> ir.RankedTensorType:
+        """Return the ranked tensor type with this format's sparse encoding for ``shape``.
+
+        Raises
+        ------
+        ValueError
+            If ``len(shape)`` differs from ``self.rank``.
+        """
+        if len(shape) != self.rank:
+            raise ValueError(f"`len(shape) != self.rank`, {shape=}, {self.rank=}")
+        mlir_levels = [level.build() for level in self.levels]
+        mlir_order = list(self.order)
+        # Inverse permutation of ``mlir_order``.
+        mlir_reverse_order = [0] * self.rank
+        for i, r in enumerate(mlir_order):
+            mlir_reverse_order[r] = i
+
+        dtype = self.dtype.get_mlir_type()
+        encoding = sparse_tensor.EncodingAttr.get(
+            mlir_levels, mlir_order, mlir_reverse_order, self.pos_width, self.crd_width
+        )
+        return ir.RankedTensorType.get(list(shape), dtype, encoding)
+
+    @fn_cache
+    def get_ctypes_type(self):
+        """Return a ``ctypes.Structure`` subclass holding the memref descriptors of this format.
+
+        Each compressed level contributes a ``pointers_to_<n>`` and an ``indices_<n>`` field; a final
+        ``values`` field holds the data.
+        """
+        ptr_dtype = asdtype(getattr(np, f"uint{self.pos_width}"))
+        idx_dtype = asdtype(getattr(np, f"uint{self.crd_width}"))
+        # Captured under a distinct name because ``self`` is rebound inside ``Format``'s methods.
+        storage_format = self
+
+        def get_fields():
+            fields = []
+            compressed_counter = 0
+            for level, next_level in itertools.zip_longest(self.levels, self.levels[1:]):
+                if LevelFormat.Compressed == level.format:
+                    compressed_counter += 1
+                    fields.append((f"pointers_to_{compressed_counter}", get_nd_memref_descr(1, ptr_dtype)))
+                    # A compressed level followed by a singleton level gets a rank-2 indices memref.
+                    if next_level is not None and LevelFormat.Singleton == next_level.format:
+                        fields.append((f"indices_{compressed_counter}", get_nd_memref_descr(2, idx_dtype)))
+                    else:
+                        fields.append((f"indices_{compressed_counter}", get_nd_memref_descr(1, idx_dtype)))
+
+            # ``get_nd_memref_descr`` calls ``dtype.to_ctype()``, so it needs the ``DType`` itself.
+            fields.append(("values", get_nd_memref_descr(1, self.dtype)))
+            return fields
+
+        class Format(ctypes.Structure, MlirType):
+            _fields_ = get_fields()
+
+            def get_mlir_type(self, *, shape: tuple[int, ...]):
+                # Delegate to the enclosing StorageFormat; calling ``self.get_mlir_type`` would recurse forever.
+                return storage_format.get_mlir_type(shape=shape)
+
+            def to_module_arg(self) -> list:
+                # One pointer-to-pointer per field.
+                return [ctypes.pointer(ctypes.pointer(f)) for f in self.get__fields_()]
+
+            def get__fields_(self) -> list:
+                return [getattr(self, field[0]) for field in self._fields_]
+
+        return Format
+
+    def __hash__(self):
+        return hash(id(self))
+
+    def __eq__(self, value):
+        return self is value