Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][sparse] unifies sparse_tensor.sort_coo/sort into one operation. #66722

Merged
merged 4 commits into from
Sep 20, 2023

Conversation

PeimingLiu
Copy link
Member

@PeimingLiu PeimingLiu commented Sep 19, 2023

The use cases of the two operations are largely overlapped, let's simplify it and only use one of them.

@llvmbot llvmbot added mlir:sparse Sparse compiler in MLIR mlir labels Sep 19, 2023
@llvmbot
Copy link
Member

llvmbot commented Sep 19, 2023

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Changes

The functionality of the two operations are largely overlapped, let's simplify it and only use one of them.


Patch is 91.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66722.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+13-62)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+10-35)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp (+142-172)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+7-5)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp (-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+16-22)
  • (modified) mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir (+23-82)
  • (modified) mlir/test/Dialect/SparseTensor/codegen.mlir (+3-3)
  • (modified) mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/invalid.mlir (+19-28)
  • (modified) mlir/test/Dialect/SparseTensor/roundtrip.mlir (+7-57)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir (+1-1)
  • (removed) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir (-187)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir (+27-24)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 94301dbcd9f7b42..59815fc755ee5f3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -762,81 +762,32 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
 // Sparse Tensor Sorting Operations.
 //===----------------------------------------------------------------------===//
 
-def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
-    Arguments<(ins Index:$n,
-               Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
-               Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
-               SparseTensorSortKindAttr:$algorithm)>  {
-  string summary = "Sorts the arrays in xs and ys lexicographically on the "
-                   "integral values found in the xs list";
-  string description = [{
-    Lexicographically sort the first `n` values in `xs` along with the values in
-    `ys`. Conceptually, the values being sorted are tuples produced by
-    `zip(zip(xs), zip(ys))`. In particular, values in `ys` needed to be sorted
-    along with values in `xs`, but values in `ys` don't affect the
-    lexicographical order. The order in which arrays appear in `xs` affects the
-    sorting result. The operator updates `xs` and `ys` in place with the result
-    of the sorting.
-
-    For example, assume x1=[4, 3], x2=[1, 2], y1=[10, 5], then the output of
-    "sort 2, x1, x2 jointly y1" are x1=[3, 4], x2=[2, 1], y1=[5, 10] while the
-    output of "sort 2, x2, x1, jointly y1" are x2=[1, 2], x1=[4, 3], y1=[10, 5].
-
-    Buffers in `xs` needs to have the same integral element type while buffers
-    in `ys` can have different numeric element types. All buffers in `xs` and
-    `ys` should have a dimension not less than `n`. The behavior of the operator
-    is undefined if this condition is not met. The operator requires at least
-    one buffer in `xs` while `ys` can be empty.
-
-    The enum attribute `algorithm` indicates the sorting algorithm used to
-    implement the operator: hybrid_quick_sort, insertion_sort_stable,
-    quick_sort, or heap_sort.
-
-    Note that this operation is "impure" in the sense that its behavior is
-    solely defined by side-effects and not SSA values.
-
-    Example:
-
-    ```mlir
-    sparse_tensor.sort insertion_sort_stable %n, %x1, %x2 jointly y1, %y2
-      : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
-    ```
-
-    ```mlir
-    sparse_tensor.sort hybrid_quick_sort %n, %x1, %x2 jointly y1, %y2
-      { alg=1 : index}
-      : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
-    ```
-  }];
-  let assemblyFormat = "$algorithm $n `,` $xs (`jointly` $ys^)? attr-dict"
-                       "`:` type($xs) (`jointly` type($ys)^)?";
-  let hasVerifier = 1;
-}
-
 def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
     Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
                Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
-               OptionalAttr<IndexAttr>:$nx, OptionalAttr<IndexAttr>:$ny,
+               AffineMapAttr:$perm_map, OptionalAttr<IndexAttr>:$ny,
                SparseTensorSortKindAttr:$algorithm)>  {
   let summary = "Sorts the arrays in xs and ys lexicographically on the "
                 "integral values found in the xs list";
   let description = [{
-    Sparse_tensor.sort_coo is similar to sparse_tensor.sort, except that all the
-    `xs` values and some `ys` values are put in the linear buffer `xy`. The
-    optional index attribute `nx` provides the number of `xs` values in `xy`.
-    When `nx` is not explicitly specified, its value is 1. The optional index
-    attribute `ny` provides the number of `ys` values in `xy`. When `ny` is not
-    explicitly specified, its value is 0. This instruction supports a more
-    efficient way to store the COO definition in sparse tensor type.
-
-    The buffer xy should have a dimension not less than n * (nx + ny) while the
+    Sparse_tensor.sort_coo sort the `xs` values along with some `ys` values
+    that are put in a single linear buffer `xy`.
+    The affine map attribute `perm_map` specifies the permutation to be applied on
+    the `xs` before comparison, the rank of the permutation map
+    also specifies the number of `xs` values in `xy`.
+    The optional index attribute `ny` provides the number of `ys` values in `xy`.
+    When `ny` is not explicitly specified, its value is 0.
+    This instruction supports a more efficient way to store the COO definition
+    in sparse tensor type.
+
+    The buffer xy should have a dimension not less than n * (rank(perm_map) + ny) while the
     buffers in `ys` should have a dimension not less than `n`. The behavior of
     the operator is undefined if this condition is not met.
 
     Example:
 
     ```mlir
-    sparse_tensor.sort_coo insertion_sort_stable %n, %x { nx = 2 : index}
+    sparse_tensor.sort_coo insertion_sort_stable %n, %x { perm_map = affine_map<(i,j) -> (j,i)> }
       : memref<?xindex>
     ```
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e71d2a8dd623a6a..9675a61109477b5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1353,35 +1353,15 @@ LogicalResult SelectOp::verify() {
   return success();
 }
 
-LogicalResult SortOp::verify() {
-  if (getXs().empty())
-    return emitError("need at least one xs buffer.");
-
-  std::optional<int64_t> n = getConstantIntValue(getN());
-
-  Type xtp = getMemRefType(getXs().front()).getElementType();
-  auto checkTypes = [&](ValueRange operands,
-                        bool checkEleType = true) -> LogicalResult {
-    for (Value opnd : operands) {
-      auto mtp = getMemRefType(opnd);
-      const DynSize sh = mtp.getShape()[0];
-      // We can't check the size of dynamic dimension at compile-time, but all
-      // xs and ys should have a dimension not less than n at runtime.
-      if (n && !ShapedType::isDynamic(sh) && sh < n.value())
-        return emitError(llvm::formatv("xs and ys need to have a dimension >= n"
-                                       ": {0} < {1}",
-                                       sh, n.value()));
-
-      if (checkEleType && xtp != mtp.getElementType())
-        return emitError("mismatch xs element types");
-    }
-    return success();
-  };
-  RETURN_FAILURE_IF_FAILED(checkTypes(getXs()))
-  return n ? checkTypes(getYs(), false) : success();
-}
-
 LogicalResult SortCooOp::verify() {
+  AffineMap xPerm = getPermMap();
+  uint64_t nx = xPerm.getNumDims();
+  if (nx < 1)
+    emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
+
+  if (!xPerm.isPermutation())
+    emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
+
   std::optional<int64_t> cn = getConstantIntValue(getN());
   // We can't check the size of the buffers when n or buffer dimensions aren't
   // compile-time constants.
@@ -1389,12 +1369,6 @@ LogicalResult SortCooOp::verify() {
     return success();
 
   uint64_t n = cn.value();
-  uint64_t nx = 1;
-  if (auto nxAttr = getNxAttr()) {
-    nx = nxAttr.getInt();
-    if (nx < 1)
-      emitError(llvm::formatv("Expected nx > 1, got {0}", nx));
-  }
   uint64_t ny = 0;
   if (auto nyAttr = getNyAttr()) {
     ny = nyAttr.getInt();
@@ -1409,7 +1383,8 @@ LogicalResult SortCooOp::verify() {
       emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
   };
 
-  checkDim(getXy(), n * (nx + ny), "Expected dimension(xy) >= n * (nx + ny)");
+  checkDim(getXy(), n * (nx + ny),
+           "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
 
   for (Value opnd : getYs()) {
     checkDim(opnd, n, "Expected dimension(y) >= n");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 029ecb0708941fe..3181395a474cfec 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -45,46 +45,43 @@ static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
 static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
 static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
 
-using FuncGeneratorType = function_ref<void(
-    OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
+using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
+                                            AffineMap, uint64_t, uint32_t)>;
 
 /// Constructs a function name with this format to facilitate quick sort:
-///   <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
-///   <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
+///   <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
+///   <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
-                                         StringRef namePrefix, uint64_t nx,
-                                         uint64_t ny, bool isCoo,
-                                         ValueRange operands) {
-  nameOstream << namePrefix << nx << "_"
-              << getMemRefType(operands[xStartIdx]).getElementType();
+                                         StringRef namePrefix, AffineMap xPerm,
+                                         uint64_t ny, ValueRange operands) {
+  nameOstream << namePrefix;
+  for (auto res : xPerm.getResults())
+    nameOstream << res.cast<AffineDimExpr>().getPosition() << "_";
 
-  if (isCoo)
-    nameOstream << "_coo_" << ny;
+  nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
+  nameOstream << "_coo_" << ny;
 
-  uint64_t yBufferOffset = isCoo ? 1 : nx;
+  constexpr uint64_t yBufferOffset = 1;
   for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
     nameOstream << "_" << getMemRefType(v).getElementType();
 }
 
 /// Looks up a function that is appropriate for the given operands being
 /// sorted, and creates such a function if it doesn't exist yet. The
-/// parameters `nx` and `ny` tell the number of x and y values provided
-/// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
-/// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
+/// parameters `xPerm` and `ny` tell the number of x and y values provided
+/// by the buffer in xStartIdx.
 //
 // All sorting function generators take (lo, hi, xs, ys) in `operands` as
 // parameters for the sorting functions. Other parameters, such as the recursive
 // call depth, are appended to the end of the parameter list as
 // "trailing parameters".
-static FlatSymbolRefAttr
-getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
-                         TypeRange resultTypes, StringRef namePrefix,
-                         uint64_t nx, uint64_t ny, bool isCoo,
-                         ValueRange operands, FuncGeneratorType createFunc,
-                         uint32_t nTrailingP = 0) {
+static FlatSymbolRefAttr getMangledSortHelperFunc(
+    OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
+    StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
+    FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
   SmallString<32> nameBuffer;
   llvm::raw_svector_ostream nameOstream(nameBuffer);
-  getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
+  getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
                                operands.drop_back(nTrailingP));
 
   ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
@@ -101,7 +98,7 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
         loc, nameOstream.str(),
         FunctionType::get(context, operands.getTypes(), resultTypes));
     func.setPrivate();
-    createFunc(builder, module, func, nx, ny, isCoo, nTrailingP);
+    createFunc(builder, module, func, xPerm, ny, nTrailingP);
   }
 
   return result;
@@ -110,27 +107,19 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
 /// The code to process the value pairs is generated by `bodyBuilder`.
 static void forEachIJPairInXs(
-    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
-    bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
-  Value iOffset, jOffset;
-  if (isCoo) {
-    Value cstep = constantIndex(builder, loc, nx + ny);
-    iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
-    jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
-  }
-  for (uint64_t k = 0; k < nx; k++) {
-    scf::IfOp ifOp;
-    Value i, j, buffer;
-    if (isCoo) {
-      Value ck = constantIndex(builder, loc, k);
-      i = builder.create<arith::AddIOp>(loc, ck, iOffset);
-      j = builder.create<arith::AddIOp>(loc, ck, jOffset);
-      buffer = args[xStartIdx];
-    } else {
-      i = args[0];
-      j = args[1];
-      buffer = args[xStartIdx + k];
-    }
+    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+    uint64_t ny,
+    function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
+  Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
+  Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
+  Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
+  for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
+    unsigned actualK = xPerm.getResult(k).cast<AffineDimExpr>().getPosition();
+    Value ak = constantIndex(builder, loc, actualK);
+    Value i = builder.create<arith::AddIOp>(loc, ak, iOffset);
+    Value j = builder.create<arith::AddIOp>(loc, ak, jOffset);
+    Value buffer = args[xStartIdx];
+
     bodyBuilder(k, i, j, buffer);
   }
 }
@@ -138,21 +127,28 @@ static void forEachIJPairInXs(
 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
 /// The code to process the value pairs is generated by `bodyBuilder`.
 static void forEachIJPairInAllBuffers(
-    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
-    bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
-
-  // Create code for the first (nx + ny) buffers. When isCoo==true, these
-  // logical buffers are all from the xy buffer of the sort_coo operator.
-  forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);
+    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+    uint64_t ny,
+    function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
+
+  // Create code for the first (xPerm + ny) buffers.
+  SmallVector<AffineExpr> exps(xPerm.getResults().begin(),
+                               xPerm.getResults().end());
+  for (unsigned y = 0; y < ny; y++) {
+    exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults()));
+  }
+  AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
+  assert(xyPerm.isPermutation());
 
-  uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;
+  forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);
 
+  constexpr uint64_t numHandledBuffers = 1;
   // Create code for the remaining buffers.
   Value i = args[0];
   Value j = args[1];
   for (const auto &arg :
        llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
-    bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
+    bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
   }
 }
 
@@ -168,7 +164,7 @@ static void forEachIJPairInAllBuffers(
 //     ...
 //     swap(yn[i], yn[j]);
 static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
-                       uint64_t nx, uint64_t ny, bool isCoo) {
+                       AffineMap xPerm, uint64_t ny) {
   auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
     Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
     Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
@@ -176,20 +172,20 @@ static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
     builder.create<memref::StoreOp>(loc, vi, buffer, j);
   };
 
-  forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
+  forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair);
 }
 
 /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
 /// each pair is create via `compareBuilder`.
 static Value createInlinedCompareImplementation(
-    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
-    bool isCoo,
+    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+    uint64_t ny,
     function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
         compareBuilder) {
   Value result;
   auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
     bool isFirstDim = (k == 0);
-    bool isLastDim = (k == nx - 1);
+    bool isLastDim = (k == xPerm.getNumResults() - 1);
     Value val =
         compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
     if (isFirstDim) {
@@ -202,7 +198,7 @@ static Value createInlinedCompareImplementation(
     }
   };
 
-  forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);
+  forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
 
   builder.setInsertionPointAfterValue(result);
   return result;
@@ -252,12 +248,12 @@ static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
 //     else if (x2[2] != x2[j]))
 //       and so on ...
 static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
-                                    ValueRange args, uint64_t nx, uint64_t ny,
-                                    bool isCoo, uint32_t nTrailingP = 0) {
+                                    ValueRange args, AffineMap xPerm,
+                                    uint64_t ny, uint32_t nTrailingP = 0) {
   // Compare functions don't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
-  return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
+  return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
                                             createEqCompare);
 }
 
@@ -306,12 +302,12 @@ static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
 //   else
 //       and so on ...
 static Value createInlinedLessThan(OpBuilder &builder, Location loc,
-                                   ValueRange args, uint64_t nx, uint64_t ny,
-                                   bool isCoo, uint32_t nTrailingP = 0) {
+                                   ValueRange args, AffineMap xPerm,
+                                   uint64_t ny, uint32_t nTrailingP = 0) {
   // Compare functions don't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
-  return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
+  return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
                                             createLessThanCompare);
 }
 
@@ -329,8 +325,8 @@ static Value createInlinedLessThan(OpBuilder &builder, Location loc,
 //   return lo;
 //
 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
-                                   func::FuncOp func, uint64_t nx, uint64_t ny,
-                                   bool isCoo, uint32_t nTrailingP = 0) {
+                                   func::FuncOp func, AffineMap xPerm,
+                                   uint64_t ny, uint32_t nTrailingP = 0) {
   // Binary search doesn't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
@@ -368,11 +364,10 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
 
   // Compare xs[p] < xs[mid].
   SmallVector<Value> compareOperands{p, mid};
-  uint64_t numXBuffers = isCoo ? 1 : nx;
+  constexpr uint64_t numXBuffers = 1;
   compareOp...
[truncated]

"`:` type($xs) (`jointly` type($ys)^)?";
let hasVerifier = 1;
}

def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very nice, one op to rule them all!

@PeimingLiu PeimingLiu merged commit bfa3bc4 into llvm:main Sep 20, 2023
3 checks passed
@PeimingLiu PeimingLiu deleted the unify-code-lib branch September 20, 2023 00:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants