Skip to content

Commit

Permalink
Add format generation from levels.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Oct 14, 2024
1 parent d770f66 commit 6732b5e
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ mkdocs-jupyter = "*"
[feature.tests.tasks]
test = "pytest --pyargs sparse -n auto"
test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -vvv", env = { SPARSE_BACKEND = "MLIR" } }
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto", env = { SPARSE_BACKEND = "Finch" }, depends-on = ["precompile"] }
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto -vvv", env = { SPARSE_BACKEND = "Finch", PYTHONFAULTHANDLER = "${HOME}/faulthandler.log" }, depends-on = ["precompile"] }

[feature.tests.dependencies]
pytest = ">=3.5"
Expand Down
144 changes: 144 additions & 0 deletions sparse/mlir_backend/_levels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import ctypes
import dataclasses
import enum
import itertools
import re
import typing

Check warning on line 6 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L1-L6

Added lines #L1 - L6 were not covered by tests

import mlir.runtime as rt
from mlir import ir
from mlir.dialects import sparse_tensor

Check warning on line 10 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L8-L10

Added lines #L8 - L10 were not covered by tests

import numpy as np

Check warning on line 12 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L12

Added line #L12 was not covered by tests

from ._common import MlirType, fn_cache
from ._dtypes import DType, asdtype

Check warning on line 15 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L14-L15

Added lines #L14 - L15 were not covered by tests

_CAMEL_TO_SNAKE = [re.compile("(.)([A-Z][a-z]+)"), re.compile("([a-z0-9])([A-Z])")]

Check warning on line 17 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L17

Added line #L17 was not covered by tests


def _camel_to_snake(name: str) -> str:
for exp in _CAMEL_TO_SNAKE:
name = exp.sub(r"\1_\2", name)

Check warning on line 22 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L20-L22

Added lines #L20 - L22 were not covered by tests

return name.lower()

Check warning on line 24 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L24

Added line #L24 was not covered by tests


@fn_cache
def get_nd_memref_descr(rank: int, dtype: type[DType]) -> type:
return rt.make_nd_memref_descriptor(rank, dtype.to_ctype())

Check warning on line 29 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L27-L29

Added lines #L27 - L29 were not covered by tests


class LevelProperties(enum.Flag):
NonOrdered = enum.auto()
NonUnique = enum.auto()

Check warning on line 34 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L32-L34

Added lines #L32 - L34 were not covered by tests

def build(self) -> list[sparse_tensor.LevelProperty]:
return [getattr(sparse_tensor.LevelProperty, _camel_to_snake(p.name)) for p in type(self) if p in self]

Check warning on line 37 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L36-L37

Added lines #L36 - L37 were not covered by tests


class LevelFormat(enum.Enum):
Dense = "dense"
Compressed = "compressed"
Singleton = "singleton"

Check warning on line 43 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L40-L43

Added lines #L40 - L43 were not covered by tests

def build(self) -> sparse_tensor.LevelFormat:
return getattr(sparse_tensor.LevelFormat, self.value)

Check warning on line 46 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L45-L46

Added lines #L45 - L46 were not covered by tests


@dataclasses.dataclass
class Level:
format: LevelFormat
properties: LevelProperties = LevelProperties(0)

Check warning on line 52 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L49-L52

Added lines #L49 - L52 were not covered by tests

def build(self):
sparse_tensor.EncodingAttr.build_level_type(self.format.build(), self.properties.build())

Check warning on line 55 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L54-L55

Added lines #L54 - L55 were not covered by tests


@dataclasses.dataclass
class StorageFormat:
levels: tuple[Level, ...]
order: typing.Literal["C", "F"] | tuple[int, ...]
pos_width: int
crd_width: int
dtype: type[DType]

Check warning on line 64 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L58-L64

Added lines #L58 - L64 were not covered by tests

@property
def storage_rank(self) -> int:
return len(self.levels)

Check warning on line 68 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L66-L68

Added lines #L66 - L68 were not covered by tests

@property
def rank(self) -> int:
return self.storage_rank

Check warning on line 72 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L70-L72

Added lines #L70 - L72 were not covered by tests

def __post_init__(self):
rank = self.storage_rank
self.dtype = asdtype(self.dtype)
if self.order == "C":
self.order = tuple(range(rank))
return

Check warning on line 79 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L74-L79

Added lines #L74 - L79 were not covered by tests

if self.order == "F":
self.order = tuple(range(rank))[::-1]
return

Check warning on line 83 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L81-L83

Added lines #L81 - L83 were not covered by tests

if sorted(self.order) != list(range(rank)):
raise ValueError(f"`sorted(self.order) != list(range(rank))`, {self.order=}, {rank=}.")

Check warning on line 86 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L85-L86

Added lines #L85 - L86 were not covered by tests

self.order = tuple(self.order)

Check warning on line 88 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L88

Added line #L88 was not covered by tests

@fn_cache
def get_mlir_type(self, *, shape: tuple[int, ...]) -> ir.RankedTensorType:
if len(shape) != self.rank:
raise ValueError(f"`len(shape) != self.rank`, {shape=}, {self.rank=}")
mlir_levels = [level.build() for level in self.levels]
mlir_order = list(self.order)
mlir_reverse_order = [0] * self.rank
for i, r in enumerate(mlir_order):
mlir_reverse_order[r] = i

Check warning on line 98 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L90-L98

Added lines #L90 - L98 were not covered by tests

dtype = self.dtype.get_mlir_type()
encoding = sparse_tensor.EncodingAttr.get(

Check warning on line 101 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L100-L101

Added lines #L100 - L101 were not covered by tests
mlir_levels, mlir_order, mlir_reverse_order, self.pos_width, self.crd_width
)
return ir.RankedTensorType.get(list(shape), dtype, encoding)

Check warning on line 104 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L104

Added line #L104 was not covered by tests

@fn_cache
def get_ctypes_type(self):
ptr_dtype = asdtype(getattr(np, f"uint{self.pos_width}"))
idx_dtype = asdtype(getattr(np, f"uint{self.crd_width}"))

Check warning on line 109 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L106-L109

Added lines #L106 - L109 were not covered by tests

def get_fields():
fields = []
compressed_counter = 0
for level, next_level in itertools.zip_longest(self.levels, self.levels[1:]):
if LevelFormat.Compressed == level.format:
compressed_counter += 1
fields.append((f"pointers_to_{compressed_counter}", get_nd_memref_descr(1, ptr_dtype)))
if next_level is not None and LevelFormat.Singleton == next_level.format:
fields.append((f"indices_{compressed_counter}", get_nd_memref_descr(2, idx_dtype)))

Check warning on line 119 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L111-L119

Added lines #L111 - L119 were not covered by tests
else:
fields.append((f"indices_{compressed_counter}", get_nd_memref_descr(1, idx_dtype)))

Check warning on line 121 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L121

Added line #L121 was not covered by tests

fields.append(("values", get_nd_memref_descr(1, self.dtype.np_dtype)))
return fields

Check warning on line 124 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L123-L124

Added lines #L123 - L124 were not covered by tests

class Format(ctypes.Structure, MlirType):
_fields_ = get_fields()

Check warning on line 127 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L126-L127

Added lines #L126 - L127 were not covered by tests

def get_mlir_type(self, *, shape: tuple[int, ...]):
return self.get_mlir_type(shape=shape)

Check warning on line 130 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L129-L130

Added lines #L129 - L130 were not covered by tests

def to_module_arg(self) -> list:
return [ctypes.pointer(ctypes.pointer(f) for f in self.get__fields_())]

Check warning on line 133 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L132-L133

Added lines #L132 - L133 were not covered by tests

def get__fields_(self) -> list:
return [getattr(self, field[0]) for field in self._fields_]

Check warning on line 136 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L135-L136

Added lines #L135 - L136 were not covered by tests

return Format

Check warning on line 138 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L138

Added line #L138 was not covered by tests

def __hash__(self):
return hash(id(self))

Check warning on line 141 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L140-L141

Added lines #L140 - L141 were not covered by tests

def __eq__(self, value):
return self is value

Check warning on line 144 in sparse/mlir_backend/_levels.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_levels.py#L143-L144

Added lines #L143 - L144 were not covered by tests

0 comments on commit 6732b5e

Please sign in to comment.