Add back reshape (#800)
hameerabbasi authored Nov 13, 2024
1 parent e111324 commit 4b491b5
Showing 9 changed files with 423 additions and 47 deletions.
11 changes: 7 additions & 4 deletions pixi.toml
@@ -27,7 +27,7 @@ mkdocs-jupyter = "*"

[feature.tests.tasks]
test = "pytest --pyargs sparse -n auto"
test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -v" }
test-mlir = { cmd = "pytest --pyargs sparse.mlir_backend -v" }
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto -v", depends-on = ["precompile"] }

[feature.tests.dependencies]
@@ -55,17 +55,20 @@ finch-tensor = ">=0.1.31"
SPARSE_BACKEND = "Finch"

[feature.finch.target.osx-arm64.activation.env]
SPARSE_BACKEND = "Finch"
PYTHONFAULTHANDLER = "${HOME}/faulthandler.log"

[feature.mlir.dependencies]
scipy = ">=0.19"
mlir-python-bindings = "19.*"

[feature.mlir.target.osx-arm64.pypi-dependencies]
finch-mlir = ">=0.0.2"

[feature.mlir.activation.env]
SPARSE_BACKEND = "MLIR"

[environments]
tests = ["tests", "extras"]
docs = ["docs", "extras"]
mlir-dev = ["tests", "mlir"]
finch-dev = ["tests", "finch"]
mlir-dev = {features = ["tests", "mlir"], no-default-feature = true}
finch-dev = {features = ["tests", "finch"], no-default-feature = true}
3 changes: 2 additions & 1 deletion sparse/mlir_backend/__init__.py
@@ -27,7 +27,7 @@
uint32,
uint64,
)
from ._ops import add
from ._ops import add, reshape

__all__ = [
"add",
@@ -36,6 +36,7 @@
"to_numpy",
"to_scipy",
"levels",
"reshape",
"from_constituent_arrays",
"int8",
"int16",
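With this change, `reshape` is re-exported next to `add`, so it is importable straight from the backend package. A minimal check of the new public surface:

import sparse.mlir_backend as mlb

assert callable(mlb.add) and callable(mlb.reshape)
assert "reshape" in mlb.__all__   # added to __all__ in this commit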
5 changes: 5 additions & 0 deletions sparse/mlir_backend/_array.py
@@ -41,5 +41,10 @@ def copy(self) -> "Array":
arrs = tuple(arr.copy() for arr in self.get_constituent_arrays())
return from_constituent_arrays(format=self.format, arrays=arrs, shape=self.shape)

def asformat(self, format: StorageFormat) -> "Array":
from ._ops import asformat

return asformat(self, format=format)

def get_constituent_arrays(self) -> tuple[np.ndarray, ...]:
return self._storage.get_constituent_arrays()
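The new `Array.asformat` is thin sugar over `_ops.asformat`, imported lazily inside the method body to avoid a circular import between `_array` and `_ops`. A hedged usage sketch; `arr` and `fmt` are hypothetical placeholders for an existing `Array` and a target `StorageFormat`, neither of which is constructed in this diff:

# `arr` and `fmt` are hypothetical placeholders, not defined in this commit.
converted = arr.asformat(fmt)   # method form added here; defers to _ops.asformat(arr, format=fmt)
# When arr.format == fmt, asformat returns arr unchanged (see _ops.asformat below).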
11 changes: 11 additions & 0 deletions sparse/mlir_backend/_common.py
@@ -1,6 +1,7 @@
import ctypes
import functools
import weakref
from collections.abc import Iterable

import mlir_finch.runtime as rt

@@ -52,3 +53,13 @@ def finalizer(ptr):
ctypes.pythonapi.Py_DecRef(ptr)

weakref.finalize(owner, finalizer, ptr)


def as_shape(x) -> tuple[int]:
if not isinstance(x, Iterable):
x = (x,)

if not all(isinstance(xi, int) for xi in x):
raise TypeError("Shape must be an `int` or tuple of `int`s.")

return tuple(int(xi) for xi in x)
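A quick usage sketch of the new helper, importing it from its private module; it normalizes a bare `int` to a 1-tuple and rejects non-integer entries:

from sparse.mlir_backend._common import as_shape

print(as_shape(5))        # (5,)
print(as_shape((2, 3)))   # (2, 3)
# as_shape((2, 3.0)) raises TypeError: "Shape must be an `int` or tuple of `int`s."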
6 changes: 6 additions & 0 deletions sparse/mlir_backend/_core.py
@@ -28,6 +28,12 @@
libc.free.argtypes = [ctypes.c_void_p]
libc.free.restype = None

SHARED_LIBS = []
if DEBUG:
SHARED_LIBS.append(MLIR_C_RUNNER_UTILS)

OPT_LEVEL = 0 if DEBUG else 2

# TODO: remove global state
ctx = Context()

16 changes: 8 additions & 8 deletions sparse/mlir_backend/_dtypes.py
@@ -76,10 +76,10 @@ def np_dtype(self) -> np.dtype:
return np.dtype(getattr(np, f"uint{self.bit_width}"))


int8 = UnsignedIntegerDType(bit_width=8)
int16 = UnsignedIntegerDType(bit_width=16)
int32 = UnsignedIntegerDType(bit_width=32)
int64 = UnsignedIntegerDType(bit_width=64)
uint8 = UnsignedIntegerDType(bit_width=8)
uint16 = UnsignedIntegerDType(bit_width=16)
uint32 = UnsignedIntegerDType(bit_width=32)
uint64 = UnsignedIntegerDType(bit_width=64)


@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
@@ -89,10 +89,10 @@ def np_dtype(self) -> np.dtype:
return np.dtype(getattr(np, f"int{self.bit_width}"))


uint8 = SignedIntegerDType(bit_width=8)
uint16 = SignedIntegerDType(bit_width=16)
uint32 = SignedIntegerDType(bit_width=32)
uint64 = SignedIntegerDType(bit_width=64)
int8 = SignedIntegerDType(bit_width=8)
int16 = SignedIntegerDType(bit_width=16)
int32 = SignedIntegerDType(bit_width=32)
int64 = SignedIntegerDType(bit_width=64)


intp: SignedIntegerDType = locals()[f"int{_PTR_WIDTH}"]
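Previously the two families were swapped (`int8` was constructed as an `UnsignedIntegerDType` and vice versa); after the fix the names agree with NumPy. A small sanity check, assuming `np_dtype` is exposed as a property on these dtype dataclasses:

import numpy as np
from sparse.mlir_backend import int8, uint8

# np_dtype is assumed to be a property here; if it is a plain method, call it instead.
assert int8.np_dtype == np.dtype(np.int8)
assert uint8.np_dtype == np.dtype(np.uint8)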
106 changes: 91 additions & 15 deletions sparse/mlir_backend/_ops.py
@@ -1,14 +1,18 @@
import ctypes
import math

import mlir_finch.execution_engine
import mlir_finch.passmanager
from mlir_finch import ir
from mlir_finch.dialects import arith, complex, func, linalg, sparse_tensor, tensor

import numpy as np

from ._array import Array
from ._common import fn_cache
from ._core import CWD, DEBUG, SHARED_LIBS, ctx, pm
from ._common import as_shape, fn_cache
from ._core import CWD, DEBUG, OPT_LEVEL, SHARED_LIBS, ctx, pm
from ._dtypes import DType, IeeeComplexFloatingDType, IeeeRealFloatingDType, IntegerDType
from .levels import StorageFormat, _determine_format


@fn_cache
@@ -17,7 +21,6 @@ def get_add_module(
b_tensor_type: ir.RankedTensorType,
out_tensor_type: ir.RankedTensorType,
dtype: DType,
rank: int,
) -> ir.Module:
with ir.Location.unknown(ctx):
module = ir.Module.create()
@@ -31,7 +34,7 @@
raise RuntimeError(f"Can not add {dtype=}.")

dtype = dtype._get_mlir_type()
ordering = ir.AffineMap.get_permutation(range(rank))
max_rank = out_tensor_type.rank

with ir.InsertionPoint(module.body):

@@ -42,8 +45,13 @@ def add(a, b):
[out_tensor_type],
[a, b],
[out],
ir.ArrayAttr.get([ir.AffineMapAttr.get(p) for p in (ordering,) * 3]),
ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * rank),
ir.ArrayAttr.get(
[
ir.AffineMapAttr.get(ir.AffineMap.get_minor_identity(max_rank, t.rank))
for t in (a_tensor_type, b_tensor_type, out_tensor_type)
]
),
ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * max_rank),
)
block = generic_op.regions[0].blocks.append(dtype, dtype, dtype)
with ir.InsertionPoint(block):
@@ -72,7 +80,7 @@ def add(a, b):
if DEBUG:
(CWD / "add_module_opt.mlir").write_text(str(module))

return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)


@fn_cache
@@ -97,7 +105,7 @@ def reshape(a, shape):
if DEBUG:
(CWD / "reshape_module_opt.mlir").write_text(str(module))

return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)


@fn_cache
@@ -125,26 +133,94 @@ def broadcast_to(in_tensor):
if DEBUG:
(CWD / "broadcast_to_module_opt.mlir").write_text(str(module))

return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)


@fn_cache
def get_convert_module(
in_tensor_type: ir.RankedTensorType,
out_tensor_type: ir.RankedTensorType,
):
with ir.Location.unknown(ctx):
module = ir.Module.create()

with ir.InsertionPoint(module.body):

@func.FuncOp.from_py_func(in_tensor_type)
def convert(in_tensor):
return sparse_tensor.convert(out_tensor_type, in_tensor)

def add(x1: Array, x2: Array) -> Array:
ret_storage_format = x1.format
convert.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
if DEBUG:
(CWD / "convert_module.mlir").write_text(str(module))
pm.run(module.operation)
if DEBUG:
(CWD / "convert_module.mlir").write_text(str(module))

return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)


def add(x1: Array, x2: Array, /) -> Array:
# TODO: Determine output format via autoscheduler
ret_storage_format = _determine_format(x1.format, x2.format, dtype=x1.dtype, union=True)
ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
out_tensor_type = ret_storage_format._get_mlir_type(shape=x1.shape)
out_tensor_type = ret_storage_format._get_mlir_type(shape=np.broadcast_shapes(x1.shape, x2.shape))

# TODO: Decide what will be the output tensor_type
add_module = get_add_module(
x1._get_mlir_type(),
x2._get_mlir_type(),
out_tensor_type=out_tensor_type,
dtype=x1.dtype,
rank=x1.ndim,
)
add_module.invoke(
"add",
ctypes.pointer(ctypes.pointer(ret_storage)),
*x1._to_module_arg(),
*x2._to_module_arg(),
)
return Array(storage=ret_storage, shape=out_tensor_type.shape)
return Array(storage=ret_storage, shape=tuple(out_tensor_type.shape))


def asformat(x: Array, /, format: StorageFormat) -> Array:
if x.format == format:
return x

out_tensor_type = format._get_mlir_type(shape=x.shape)
ret_storage = format._get_ctypes_type(owns_memory=True)()

convert_module = get_convert_module(
x._get_mlir_type(),
out_tensor_type,
)

convert_module.invoke(
"convert",
ctypes.pointer(ctypes.pointer(ret_storage)),
*x._to_module_arg(),
)

return Array(storage=ret_storage, shape=x.shape)


def reshape(x: Array, /, shape: tuple[int, ...]) -> Array:
from ._conversions import _from_numpy

shape = as_shape(shape)
if math.prod(x.shape) != math.prod(shape):
raise ValueError(f"`math.prod(x.shape) != math.prod(shape)`, {x.shape=}, {shape=}")

ret_storage_format = _determine_format(x.format, dtype=x.dtype, union=len(shape) > x.ndim, out_ndim=len(shape))
shape_array = _from_numpy(np.asarray(shape, dtype=np.uint64))
out_tensor_type = ret_storage_format._get_mlir_type(shape=shape)
ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()

reshape_module = get_reshape_module(x._get_mlir_type(), shape_array._get_mlir_type(), out_tensor_type)

reshape_module.invoke(
"reshape",
ctypes.pointer(ctypes.pointer(ret_storage)),
*x._to_module_arg(),
*shape_array._to_module_arg(),
)

return Array(storage=ret_storage, shape=shape)
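An end-to-end sketch of the new public `reshape`: the target shape is normalized with `as_shape`, the element count must match (the `math.prod` check above), the shape is handed to the kernel as a `uint64` tensor, and `_determine_format` picks the output format. The helpers `asarray` and `to_numpy` used below are assumed to exist in `sparse.mlir_backend`; they are not part of this diff:

import numpy as np
import scipy.sparse as sps
import sparse.mlir_backend as mlb

dense = np.arange(12, dtype=np.float64).reshape(3, 4)
arr = mlb.asarray(sps.csr_matrix(dense))   # assumed conversion helper, not in this diff
flat = mlb.reshape(arr, (12,))             # element counts match, so no ValueError
assert flat.shape == (12,)
# mlb.reshape(arr, (5, 5)) would raise ValueError, since 3*4 != 5*5.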
(Diffs for the remaining 2 of the 9 changed files are not shown here.)
