From d511c5cbba808ddf48a56f30b50bd691edadfae1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= <mat646@gmail.com>
Date: Wed, 23 Oct 2024 11:58:00 +0000
Subject: [PATCH] ENH: Update MLIR backend to LLVM 20.dev

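Migrate the MLIR backend from the `mlir-python-bindings==19.*` conda
package to the `finch-mlir` wheel, which ships Python bindings built
against LLVM 20.dev. All imports move from `mlir.*` to `mlir_finch.*`,
and the MLIR runner utilities are resolved from the wheel's bundled
`lib` directory and collected in `SHARED_LIBS` for the execution
engine.

LLVM 20 stores COO coordinates as structure-of-arrays (SoA) singleton
levels, so the scipy conversions and the ctypes field layout now use
one 1-D coordinate array per dimension instead of a single stacked
(nnz, ndim) array. The `tensor.empty` builder now takes a shape and an
element type, with the sparse encoding passed as `encoding=`.

Re-enable tests previously blocked by
https://github.com/llvm/llvm-project/pull/108615,
https://github.com/llvm/llvm-project/pull/109135, and
https://discourse.llvm.org/t/passmanager-fails-on-simple-coo-addition-example/81247.

Pin setup-micromamba to v1.9.0 with micromamba 1.5.10-0 to work around
https://github.com/mamba-org/setup-micromamba/issues/227.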
---
 .github/workflows/ci.yml                 |  4 ++-
 ci/environment.yml                       |  4 +--
 sparse/mlir_backend/__init__.py          |  4 +--
 sparse/mlir_backend/_common.py           |  2 +-
 sparse/mlir_backend/_conversions.py      | 18 +++++++-----
 sparse/mlir_backend/_core.py             | 18 ++++++++++--
 sparse/mlir_backend/_dtypes.py           |  4 +--
 sparse/mlir_backend/_ops.py              | 22 ++++++++-------
 sparse/mlir_backend/levels.py            | 20 +++++++++++--
 sparse/mlir_backend/tests/test_simple.py | 36 ++++++++++++------------
 10 files changed, 84 insertions(+), 48 deletions(-)
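
Reviewer note (text between the diffstat and the first diff is ignored
by `git am`): a minimal sketch of the new SoA COO layout, mirroring the
updated tests below. It assumes `sparse` resolves to the MLIR backend
namespace as in `tests/test_simple.py`; the widths, dtype, and values
are illustrative.

    import numpy as np
    import sparse

    # 2-D COO: Compressed(NonUnique) batch level + SoA Singleton level.
    coo_format = sparse.levels.get_storage_format(
        levels=(
            sparse.levels.Level(
                sparse.levels.LevelFormat.Compressed,
                sparse.levels.LevelProperties.NonUnique,
            ),
            sparse.levels.Level(
                sparse.levels.LevelFormat.Singleton,
                sparse.levels.LevelProperties.SOA,
            ),
        ),
        order=(0, 1),
        pos_width=64,
        crd_width=64,
        dtype=np.float64,
    )

    pos = np.array([0, 3], dtype=np.int64)     # one batch, nnz == 3
    row = np.array([0, 0, 1], dtype=np.int64)  # dim-0 coordinates (SoA)
    col = np.array([1, 3, 2], dtype=np.int64)  # dim-1 coordinates (SoA)
    data = np.array([1.0, 2.0, 3.0])

    # Constituent arrays are now (pos, row, col, data); previously the
    # coordinates were one stacked (nnz, 2) array: (pos, coords, data).
    arr = sparse.from_constituent_arrays(
        format=coo_format, arrays=(pos, row, col, data), shape=(2, 4)
    )
    assert len(arr.get_constituent_arrays()) == 4

The `tensor.empty` builder change, as it appears in `_ops.py` below:

    # LLVM 19:      out = tensor.empty(out_tensor_type, [])
    # LLVM 20.dev:  out = tensor.empty(
    #     out_tensor_type.shape, out_tensor_type.element_type,
    #     encoding=out_tensor_type.encoding,
    # )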

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 636500b9..6277c3e3 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -34,8 +34,10 @@ jobs:
     steps:
       - name: Checkout Repo
         uses: actions/checkout@v4
-      - uses: mamba-org/setup-micromamba@v2
+      - uses: mamba-org/setup-micromamba@v1.9.0
         with:
+          # NOTE: https://github.com/mamba-org/setup-micromamba/issues/227
+          micromamba-version: '1.5.10-0'
           environment-file: ci/environment.yml
           init-shell: >-
             bash
diff --git a/ci/environment.yml b/ci/environment.yml
index cb49a1e5..e86b34b7 100644
--- a/ci/environment.yml
+++ b/ci/environment.yml
@@ -12,7 +12,7 @@ dependencies:
   - pytest
   - pytest-cov
   - pytest-xdist
-  - mlir-python-bindings==19.*
   - pip:
-    - finch-tensor >=0.1.31
+    - finch-tensor>=0.1.31
+    - finch-mlir>=0.0.2
     - pytest-codspeed
diff --git a/sparse/mlir_backend/__init__.py b/sparse/mlir_backend/__init__.py
index eb299d77..20a02beb 100644
--- a/sparse/mlir_backend/__init__.py
+++ b/sparse/mlir_backend/__init__.py
@@ -1,7 +1,7 @@
 try:
-    import mlir  # noqa: F401
+    import mlir_finch  # noqa: F401
 
-    del mlir
+    del mlir_finch
 except ModuleNotFoundError as e:
     raise ImportError(
         "MLIR Python bindings not installed. Run "
diff --git a/sparse/mlir_backend/_common.py b/sparse/mlir_backend/_common.py
index 84b639cb..b382a822 100644
--- a/sparse/mlir_backend/_common.py
+++ b/sparse/mlir_backend/_common.py
@@ -2,7 +2,7 @@
 import functools
 import weakref
 
-import mlir.runtime as rt
+import mlir_finch.runtime as rt
 
 import numpy as np
 
diff --git a/sparse/mlir_backend/_conversions.py b/sparse/mlir_backend/_conversions.py
index 08f56ea7..743cacb1 100644
--- a/sparse/mlir_backend/_conversions.py
+++ b/sparse/mlir_backend/_conversions.py
@@ -94,13 +94,17 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array:
         case "coo":
             if copy is not None and not copy:
                 raise RuntimeError(f"`scipy.sparse.{type(arr.__name__)}` cannot be zero-copy converted.")
-            coords = np.stack([arr.row, arr.col], axis=1)
+            row, col = arr.row, arr.col
+            if row.dtype != col.dtype:
+                raise RuntimeError(f"`row` and `col` dtypes must be the same: {row.dtype} != {col.dtype}.")
             pos = np.array([0, arr.nnz], dtype=np.int64)
             pos_width = pos.dtype.itemsize * 8
-            crd_width = coords.dtype.itemsize * 8
+            crd_width = row.dtype.itemsize * 8
             data = arr.data
             if copy:
-                data = arr.data.copy()
+                data = data.copy()
+                row = row.copy()
+                col = col.copy()
 
             level_props = LevelProperties(0)
             if not arr.has_canonical_format:
@@ -109,7 +113,7 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array:
             coo_format = get_storage_format(
                 levels=(
                     Level(LevelFormat.Compressed, level_props | LevelProperties.NonUnique),
-                    Level(LevelFormat.Singleton, level_props),
+                    Level(LevelFormat.Singleton, level_props | LevelProperties.SOA),
                 ),
                 order=(0, 1),
                 pos_width=pos_width,
@@ -117,7 +121,7 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array:
                 dtype=arr.dtype,
             )
 
-            return from_constituent_arrays(format=coo_format, arrays=(pos, coords, data), shape=arr.shape)
+            return from_constituent_arrays(format=coo_format, arrays=(pos, row, col, data), shape=arr.shape)
         case _:
             raise NotImplementedError(f"No conversion implemented for `scipy.sparse.{type(arr.__name__)}`.")
 
@@ -133,8 +137,8 @@ def to_scipy(arr: Array) -> ScipySparseArray:
                 return sps.csr_array((data, indices, indptr), shape=arr.shape)
             return sps.csc_array((data, indices, indptr), shape=arr.shape)
         case (Level(LevelFormat.Compressed, _), Level(LevelFormat.Singleton, _)):
-            _, coords, data = arr.get_constituent_arrays()
-            return sps.coo_array((data, (coords[:, 0], coords[:, 1])), shape=arr.shape)
+            _, row, col, data = arr.get_constituent_arrays()
+            return sps.coo_array((data, (row, col)), shape=arr.shape)
         case _:
             raise RuntimeError(f"No conversion implemented for `{storage_format=}`.")
 
diff --git a/sparse/mlir_backend/_core.py b/sparse/mlir_backend/_core.py
index 16e6720b..9488ea46 100644
--- a/sparse/mlir_backend/_core.py
+++ b/sparse/mlir_backend/_core.py
@@ -2,14 +2,28 @@
 import ctypes.util
 import os
 import pathlib
+import sys
 
-from mlir.ir import Context
-from mlir.passmanager import PassManager
+from mlir_finch.ir import Context
+from mlir_finch.passmanager import PassManager
 
 DEBUG = bool(int(os.environ.get("DEBUG", "0")))
 CWD = pathlib.Path(".")
 
+finch_lib_path = f"{sys.prefix}/lib/python3.{sys.version_info.minor}/site-packages/lib"
+
+ld_library_path = os.environ.get("LD_LIBRARY_PATH")
+ld_library_path = f"{finch_lib_path}:{ld_library_path}" if ld_library_path is None else finch_lib_path
+os.environ["LD_LIBRARY_PATH"] = ld_library_path
+
 MLIR_C_RUNNER_UTILS = ctypes.util.find_library("mlir_c_runner_utils")
+if os.name == "posix" and MLIR_C_RUNNER_UTILS is not None:
+    MLIR_C_RUNNER_UTILS = f"{finch_lib_path}/{MLIR_C_RUNNER_UTILS}"
+
+SHARED_LIBS = []
+if MLIR_C_RUNNER_UTILS is not None:
+    SHARED_LIBS.append(MLIR_C_RUNNER_UTILS)
+
 libc = ctypes.CDLL(ctypes.util.find_library("c")) if os.name != "nt" else ctypes.cdll.msvcrt
 libc.free.argtypes = [ctypes.c_void_p]
 libc.free.restype = None
diff --git a/sparse/mlir_backend/_dtypes.py b/sparse/mlir_backend/_dtypes.py
index 0c48d4b0..31d8c5f8 100644
--- a/sparse/mlir_backend/_dtypes.py
+++ b/sparse/mlir_backend/_dtypes.py
@@ -3,8 +3,8 @@
 import math
 import sys
 
-import mlir.runtime as rt
-from mlir import ir
+import mlir_finch.runtime as rt
+from mlir_finch import ir
 
 import numpy as np
 
diff --git a/sparse/mlir_backend/_ops.py b/sparse/mlir_backend/_ops.py
index 180af401..20eee897 100644
--- a/sparse/mlir_backend/_ops.py
+++ b/sparse/mlir_backend/_ops.py
@@ -1,13 +1,13 @@
 import ctypes
 
-import mlir.execution_engine
-import mlir.passmanager
-from mlir import ir
-from mlir.dialects import arith, complex, func, linalg, sparse_tensor, tensor
+import mlir_finch.execution_engine
+import mlir_finch.passmanager
+from mlir_finch import ir
+from mlir_finch.dialects import arith, complex, func, linalg, sparse_tensor, tensor
 
 from ._array import Array
 from ._common import fn_cache
-from ._core import CWD, DEBUG, MLIR_C_RUNNER_UTILS, ctx, pm
+from ._core import CWD, DEBUG, SHARED_LIBS, ctx, pm
 from ._dtypes import DType, IeeeComplexFloatingDType, IeeeRealFloatingDType, IntegerDType
 
 
@@ -37,7 +37,7 @@ def get_add_module(
 
             @func.FuncOp.from_py_func(a_tensor_type, b_tensor_type)
             def add(a, b):
-                out = tensor.empty(out_tensor_type, [])
+                out = tensor.empty(out_tensor_type.shape, dtype, encoding=out_tensor_type.encoding)
                 generic_op = linalg.GenericOp(
                     [out_tensor_type],
                     [a, b],
@@ -72,7 +72,7 @@ def add(a, b):
         if DEBUG:
             (CWD / "add_module_opt.mlir").write_text(str(module))
 
-    return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
 
 
 @fn_cache
@@ -97,7 +97,7 @@ def reshape(a, shape):
             if DEBUG:
                 (CWD / "reshape_module_opt.mlir").write_text(str(module))
 
-    return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
 
 
 @fn_cache
@@ -113,7 +113,9 @@ def get_broadcast_to_module(
 
             @func.FuncOp.from_py_func(in_tensor_type)
             def broadcast_to(in_tensor):
-                out = tensor.empty(out_tensor_type, [])
+                out = tensor.empty(
+                    out_tensor_type.shape, out_tensor_type.element_type, encoding=out_tensor_type.encoding
+                )
                 return linalg.broadcast(in_tensor, outs=[out], dimensions=dimensions)
 
             broadcast_to.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
@@ -123,7 +125,7 @@ def broadcast_to(in_tensor):
             if DEBUG:
                 (CWD / "broadcast_to_module_opt.mlir").write_text(str(module))
 
-    return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
 
 
 def add(x1: Array, x2: Array) -> Array:
diff --git a/sparse/mlir_backend/levels.py b/sparse/mlir_backend/levels.py
index 0eef625a..6559ec4c 100644
--- a/sparse/mlir_backend/levels.py
+++ b/sparse/mlir_backend/levels.py
@@ -5,8 +5,8 @@
 import re
 import typing
 
-from mlir import ir
-from mlir.dialects import sparse_tensor
+from mlir_finch import ir
+from mlir_finch.dialects import sparse_tensor
 
 import numpy as np
 
@@ -36,6 +36,7 @@ def _camel_to_snake(name: str) -> str:
 class LevelProperties(enum.Flag):
     NonOrdered = enum.auto()
     NonUnique = enum.auto()
+    SOA = enum.auto()
 
     def build(self) -> list[sparse_tensor.LevelProperty]:
         return [getattr(sparse_tensor.LevelProperty, _camel_to_snake(p.name)) for p in type(self) if p in self]
@@ -108,15 +109,28 @@ def _get_ctypes_type(self, *, owns_memory=False):
         def get_fields():
             fields = []
             compressed_counter = 0
+            singleton_counter = 0
             for level, next_level in itertools.zip_longest(self.levels, self.levels[1:]):
                 if LevelFormat.Compressed == level.format:
                     compressed_counter += 1
                     fields.append((f"pointers_to_{compressed_counter}", get_nd_memref_descr(1, ptr_dtype)))
                     if next_level is not None and LevelFormat.Singleton == next_level.format:
-                        fields.append((f"indices_{compressed_counter}", get_nd_memref_descr(2, idx_dtype)))
+                        singleton_counter += 1
+                        fields.append(
+                            (
+                                f"indices_{compressed_counter}_coords_{singleton_counter}",
+                                get_nd_memref_descr(1, idx_dtype),
+                            )
+                        )
                     else:
                         fields.append((f"indices_{compressed_counter}", get_nd_memref_descr(1, idx_dtype)))
 
+                if LevelFormat.Singleton == level.format:
+                    singleton_counter += 1
+                    fields.append(
+                        (f"indices_{compressed_counter}_coords_{singleton_counter}", get_nd_memref_descr(1, idx_dtype))
+                    )
+
             fields.append(("values", get_nd_memref_descr(1, self.dtype)))
             return fields
 
diff --git a/sparse/mlir_backend/tests/test_simple.py b/sparse/mlir_backend/tests/test_simple.py
index d932d216..f8ae1f31 100644
--- a/sparse/mlir_backend/tests/test_simple.py
+++ b/sparse/mlir_backend/tests/test_simple.py
@@ -176,20 +176,18 @@ def test_add(rng, dtype):
     assert isinstance(actual, np.ndarray)
     np.testing.assert_array_equal(actual, expected)
 
-    # NOTE: Fixed in https://github.com/llvm/llvm-project/pull/108615
-    # actual = sparse.add(c_tensor, c_tensor).to_scipy_sparse()
-    # expected = c + c
-    # assert isinstance(actual, np.ndarray)
-    # np.testing.assert_array_equal(actual, expected)
+    actual = sparse.to_numpy(sparse.add(dense_tensor, dense_tensor))
+    expected = dense + dense
+    assert isinstance(actual, np.ndarray)
+    np.testing.assert_array_equal(actual, expected)
 
     actual = sparse.to_scipy(sparse.add(csr_2_tensor, coo_tensor))
     expected = csr_2 + coo
     assert_csx_equal(expected, actual)
 
-    # NOTE: https://discourse.llvm.org/t/passmanager-fails-on-simple-coo-addition-example/81247
-    # actual = sparse.add(d_tensor, d_tensor).to_scipy_sparse()
-    # expected = d + d
-    # np.testing.assert_array_equal(actual.todense(), expected.todense())
+    actual = sparse.to_scipy(sparse.add(coo_tensor, coo_tensor))
+    expected = coo + coo
+    np.testing.assert_array_equal(actual.todense(), expected.todense())
 
 
 @parametrize_dtypes
@@ -226,8 +224,11 @@ def test_coo_3d_format(dtype):
     format = sparse.levels.get_storage_format(
         levels=(
             sparse.levels.Level(sparse.levels.LevelFormat.Compressed, sparse.levels.LevelProperties.NonUnique),
-            sparse.levels.Level(sparse.levels.LevelFormat.Singleton, sparse.levels.LevelProperties.NonUnique),
-            sparse.levels.Level(sparse.levels.LevelFormat.Singleton, sparse.levels.LevelProperties(0)),
+            sparse.levels.Level(
+                sparse.levels.LevelFormat.Singleton,
+                sparse.levels.LevelProperties.NonUnique | sparse.levels.LevelProperties.SOA,
+            ),
+            sparse.levels.Level(sparse.levels.LevelFormat.Singleton, sparse.levels.LevelProperties.SOA),
         ),
         order="C",
         pos_width=64,
@@ -237,20 +238,19 @@ def test_coo_3d_format(dtype):
 
     SHAPE = (2, 2, 4)
     pos = np.array([0, 7])
-    crd = np.array([[0, 1, 0, 0, 1, 1, 0], [1, 3, 1, 0, 0, 1, 0], [3, 1, 1, 0, 1, 1, 1]])
+    crd = [np.array([0, 1, 0, 0, 1, 1, 0]), np.array([1, 3, 1, 0, 0, 1, 0]), np.array([3, 1, 1, 0, 1, 1, 1])]
     data = np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype)
-    carrs = (pos, crd, data)
+    carrs = (pos, *crd, data)
 
     coo_array = sparse.from_constituent_arrays(format=format, arrays=carrs, shape=SHAPE)
     result = coo_array.get_constituent_arrays()
     for actual, expected in zip(result, carrs, strict=True):
         np.testing.assert_array_equal(actual, expected)
 
-    # NOTE: Blocked by https://github.com/llvm/llvm-project/pull/109135
-    # res_arrays = sparse.add(coo_array, coo_array).get_constituent_arrays()
-    # res_expected = (pos, crd, data * 2)
-    # for actual, expected in zip(res_arrays, res_expected, strict=False):
-    #     np.testing.assert_array_equal(actual, expected)
+    result_arrays = sparse.add(coo_array, coo_array).get_constituent_arrays()
+    constituent_arrays = (pos, *crd, data * 2)
+    for actual, expected in zip(result_arrays, constituent_arrays, strict=False):
+        np.testing.assert_array_equal(actual, expected)
 
 
 @parametrize_dtypes