diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index d83d1ba03feb84..59815fc755ee5f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -765,27 +765,29 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
 def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
     Arguments<(ins Index:$n,
                StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
                Variadic<StridedMemRefRankOf<[AnyNumber], [1]>>:$ys,
-               AffineMapAttr:$nx, OptionalAttr<IndexAttr>:$ny,
+               AffineMapAttr:$perm_map, OptionalAttr<IndexAttr>:$ny,
                SparseTensorSortKindAttr:$algorithm)>  {
   let summary = "Sorts the arrays in xs and ys lexicographically on the "
                 "integral values found in the xs list";
   let description = [{
-    Sparse_tensor.sort_coo is similar to sparse_tensor.sort, except that all the
-    `xs` values and some `ys` values are put in the linear buffer `xy`. The
-    optional index attribute `nx` provides the number of `xs` values in `xy`.
-    When `nx` is not explicitly specified, its value is 1. The optional index
-    attribute `ny` provides the number of `ys` values in `xy`. When `ny` is not
-    explicitly specified, its value is 0. This instruction supports a more
-    efficient way to store the COO definition in sparse tensor type.
-
-    The buffer xy should have a dimension not less than n * (nx + ny) while the
+    Sparse_tensor.sort_coo sorts the `xs` values along with some `ys` values
+    that are put in a single linear buffer `xy`.
+    The affine map attribute `perm_map` specifies the permutation to be
+    applied on the `xs` before comparison; the rank of the permutation map
+    also specifies the number of `xs` values in `xy`.
+    The optional index attribute `ny` provides the number of `ys` values in
+    `xy`. When `ny` is not explicitly specified, its value is 0.
+    This operation supports a more efficient way to store the COO definition
+    in sparse tensor type.
+
+    The buffer `xy` should have a dimension not less than `n * (rank(perm_map) + ny)` while the
     buffers in `ys` should have a dimension not less than `n`. The behavior of
     the operator is undefined if this condition is not met.

     Example:

     ```mlir
-    sparse_tensor.sort_coo insertion_sort_stable %n, %x { nx = 2 : index}
+    sparse_tensor.sort_coo insertion_sort_stable %n, %x { perm_map = affine_map<(i,j) -> (j,i)> }
       : memref<?xindex>
     ```
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 3cd0847bdf7376..9675a61109477b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1354,6 +1354,14 @@ LogicalResult SelectOp::verify() {
 }

 LogicalResult SortCooOp::verify() {
+  AffineMap xPerm = getPermMap();
+  uint64_t nx = xPerm.getNumDims();
+  if (nx < 1)
+    emitError(llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx));
+
+  if (!xPerm.isPermutation())
+    emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
+
   std::optional<int64_t> cn = getConstantIntValue(getN());
   // We can't check the size of the buffers when n or buffer dimensions aren't
   // compile-time constants.
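A quick note on the renamed attribute before the remaining hunks: `perm_map` (formerly `nx`) is an affine permutation map whose rank gives the number of `xs` keys interleaved in `xy` and whose result order gives the comparison order. A minimal usage sketch of the new syntax (the function and map names here are hypothetical, not part of this patch):

```mlir
// Two keys per record, compared on the second key first because of the
// (i, j) -> (j, i) permutation. With ny = 0, the buffer %xy must hold at
// least n * (rank(perm_map) + ny) = n * 2 entries.
#SWAP = affine_map<(i, j) -> (j, i)>

func.func @sort_coo_swapped_keys(%n: index, %xy: memref<?xindex>) -> memref<?xindex> {
  sparse_tensor.sort_coo quick_sort %n, %xy {perm_map = #SWAP, ny = 0 : index}
    : memref<?xindex>
  return %xy : memref<?xindex>
}
```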
@@ -1361,12 +1369,6 @@ LogicalResult SortCooOp::verify() {
     return success();
   uint64_t n = cn.value();
-  uint64_t nx = 1;
-  if (auto nxAttr = getNxAttr()) {
-    nx = nxAttr.getAffineMap().getNumResults();
-    if (nx < 1)
-      emitError(llvm::formatv("Expected nx > 1, got {0}", nx));
-  }
   uint64_t ny = 0;
   if (auto nyAttr = getNyAttr()) {
     ny = nyAttr.getInt();
@@ -1381,7 +1383,8 @@
     emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
   };

-  checkDim(getXy(), n * (nx + ny), "Expected dimension(xy) >= n * (nx + ny)");
+  checkDim(getXy(), n * (nx + ny),
+           "Expected dimension(xy) >= n * (rank(perm_map) + ny)");

   for (Value opnd : getYs()) {
     checkDim(opnd, n, "Expected dimension(y) >= n");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 101bd165cc598b..3181395a474cfe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -54,8 +54,11 @@ using FuncGeneratorType = function_ref<void(
-  nameOstream << namePrefix << nx << "_"
-              << getMemRefType(operands[xStartIdx]).getElementType();
+  nameOstream << namePrefix;
+  for (auto d : xPerm.getResults())
+    nameOstream << d.cast<AffineDimExpr>().getPosition() << "_";
+
+  nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
   nameOstream << "_coo_" << ny;

   constexpr uint64_t yBufferOffset = 1;
@@ -1405,7 +1408,7 @@ struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
     xys.push_back(op.getXy());
     xys.append(op.getYs().begin(), op.getYs().end());

-    auto xPerm = op.getNx();
+    auto xPerm = op.getPermMap();
     uint64_t ny = 0;
     if (auto nyAttr = op.getNyAttr())
       ny = nyAttr.getInt();
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index 0036bd5c3310b9..c96a55aa1e8b2f 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -75,123 +75,64 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
 // -----

-// CHECK-LABEL: func.func private @_sparse_partition_1_i8_f32_index
-// CHECK-LABEL: func.func private @_sparse_qsort_1_i8_f32_index
-// CHECK-LABEL: func.func @sparse_sort_1d2v_quick
-func.func @sparse_sort_1d2v_quick(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
-   -> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
-  sparse_tensor.sort quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
-  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
-}
-
-// -----
-
-// Only check the generated supporting function now. We have integration test
-// to verify correctness of the generated code.
-
-// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-LABEL: func.func @sparse_sort_3d_quick
-func.func @sparse_sort_3d_quick(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
-  sparse_tensor.sort quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-}
-
-// -----
-
-// Only check the generated supporting function now. We have integration test
-// to verify correctness of the generated code.
-//
-// CHECK-DAG: func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-DAG: func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
-// CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: i64) {
-// CHECK-LABEL: func.func @sparse_sort_3d_hybrid
-func.func @sparse_sort_3d_hybrid(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-}
-
-// -----
-
-// Only check the generated supporting functions. We have integration test to
-// verify correctness of the generated code.
-//
-// CHECK-DAG: func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-LABEL: func.func @sparse_sort_3d_stable
-func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
-  sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-}
-
-// -----
+#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>

 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG: func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
-// CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-LABEL: func.func @sparse_sort_3d_heap
-func.func @sparse_sort_3d_heap(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
-  sparse_tensor.sort heap_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-}
-
-// -----
-
-// Only check the generated supporting functions. We have integration test to
-// verify correctness of the generated code.
-//
-// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG: func.func private @_sparse_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_partition_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_qsort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL: func.func @sparse_sort_coo_quick
 func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort_coo quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }

 // -----

+#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>
+
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG: func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
-// CHECK-DAG: func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
-// CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
-// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG: func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
+// CHECK-DAG: func.func private @_sparse_binary_search_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_sort_stable_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_shift_down_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
+// CHECK-DAG: func.func private @_sparse_heap_sort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_partition_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_hybrid_qsort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
 // CHECK-LABEL: func.func @sparse_sort_coo_hybrid
 func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }

 // -----

+#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>
+
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG: func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_binary_search_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_sort_stable_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL: func.func @sparse_sort_coo_stable
 func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }

 // -----

+#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>
+
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG: func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
-// CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_shift_down_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
+// CHECK-DAG: func.func private @_sparse_heap_sort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL: func.func @sparse_sort_coo_heap
 func.func @sparse_sort_coo_heap(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index f1317f23d65684..ea11a98b76ec63 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -436,7 +436,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
 // CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64
 // CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref<?xindex>
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
 // CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]])
 // CHECK: %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref<?xindex>
 // CHECK: %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref<?xf64>
@@ -484,7 +484,7 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
 // CHECK: %[[A11:.*]] = arith.constant 0.000000e+00 : f64
 // CHECK: %[[A12:.*]] = arith.constant 1 : index
 // CHECK: %[[A13:.*]] = arith.constant 0 : index
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref<?xindex>
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
 // CHECK: %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 // CHECK: %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref<?xindex>
 // CHECK: %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref<?xf64>
@@ -712,7 +712,7 @@ func.func @sparse_convert_element_type(%arg0: tensor<32xf32, #SparseVector>) ->
 // CHECK: %[[A33:.*]] = call @getSparseTensorReaderReadToBuffers0F32(%[[A5]], %[[A32]], %[[A14]], %[[A15]])
 // CHECK: %[[A34:.*]] = arith.cmpi eq, %[[A33]], %[[A1]] : i1
 // CHECK: scf.if %[[A34]] {
-// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A10]], %[[A14]] jointly %[[A15]] {nx = 2 : index, ny = 0 : index} : memref<?xindex> jointly memref<?xf32>
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A10]], %[[A14]] jointly %[[A15]] {ny = 0 : index, perm_map = #{{.*}}} : memref<?xindex> jointly memref<?xf32>
 // CHECK: }
 // CHECK: memref.store %[[A10]], %[[A27]]{{\[}}%[[A2]]] : memref<?xindex>
 // CHECK: %[[A36:.*]] = sparse_tensor.storage_specifier.set %[[A30]] crd_mem_sz at 0 with %[[A11]]
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index b3eb50f1755dac..54cdfc690952d9 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -178,7 +178,7 @@ func.func @sparse_convert_singleton(%arg0: tensor<?xf32, #SparseSingleton64>) ->
 // CHECK-RWT: %[[VAL_16:.*]] = sparse_tensor.load %[[VAL_17:.*]] hasInserts : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-RWT: %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_16]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
 // CHECK-RWT: %[[VAL_19:.*]] = sparse_tensor.coordinates_buffer %[[VAL_16]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-RWT: sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_7]], %[[VAL_19]] jointly %[[VAL_18]] {nx = 3 : index, ny = 0 : index}
+// CHECK-RWT: sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_7]], %[[VAL_19]] jointly %[[VAL_18]] {ny = 0 : index, perm_map = #map}
 // CHECK-RWT: %[[VAL_20:.*]] = bufferization.alloc_tensor(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) size_hint=%[[VAL_7]]
 // CHECK-RWT: %[[VAL_21:.*]] = sparse_tensor.foreach in %[[VAL_16]] init(%[[VAL_20]])
 // CHECK-RWT: ^bb0(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index, %[[VAL_24:.*]]: index, %[[VAL_25:.*]]: f32, %[[VAL_26:.*]]: tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>):
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 71e6eebb30261c..c0e813dcde7c57 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -790,60 +790,51 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (
   return
 }

-// -----
-
-// TODO: a test case with empty xs doesn't work due to some parser issues.
-
-func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) {
-  // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1: memref<?xf32>
-}
-
-// -----
-
-func.func @sparse_sort_dim_too_small(%arg0: memref<10xindex>) {
-  %i20 = arith.constant 20 : index
-  // expected-error@+1 {{xs and ys need to have a dimension >= n: 10 < 20}}
-  sparse_tensor.sort insertion_sort_stable %i20, %arg0 : memref<10xindex>
-  return
-}

 // -----

-func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<10xi8>) {
-  // expected-error@+1 {{mismatch xs element types}}
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
-  return
-}
-
-// -----
+#MAP = affine_map<(i,j) -> (i,j)>

 func.func @sparse_sort_coo_x_type( %arg0: index, %arg1: memref<?xf32>) {
   // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
-  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1: memref<?xf32>
+  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 {perm_map = #MAP} : memref<?xf32>
   return
 }

 // -----

+#MAP = affine_map<(i,j) -> (i,j)>
+
 func.func @sparse_sort_coo_x_too_small(%arg0: memref<50xindex>) {
   %i20 = arith.constant 20 : index
-  // expected-error@+1 {{Expected dimension(xy) >= n * (nx + ny) got 50 < 60}}
-  sparse_tensor.sort_coo hybrid_quick_sort %i20, %arg0 {nx = 2 : index, ny = 1 : index} : memref<50xindex>
+  // expected-error@+1 {{Expected dimension(xy) >= n * (rank(perm_map) + ny) got 50 < 60}}
+  sparse_tensor.sort_coo hybrid_quick_sort %i20, %arg0 {perm_map = #MAP, ny = 1 : index} : memref<50xindex>
   return
 }

 // -----

+#MAP = affine_map<(i,j) -> (i,j)>
+
 func.func @sparse_sort_coo_y_too_small(%arg0: memref<60xindex>, %arg1: memref<10xf32>) {
   %i20 = arith.constant 20 : index
   // expected-error@+1 {{Expected dimension(y) >= n got 10 < 20}}
-  sparse_tensor.sort_coo insertion_sort_stable %i20, %arg0 jointly %arg1 {nx = 2 : index, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
+  sparse_tensor.sort_coo insertion_sort_stable %i20, %arg0 jointly %arg1 {perm_map = #MAP, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
   return
 }

 // -----

+#NON_PERM_MAP = affine_map<(i,j) -> (i,i)>
+
+func.func @sparse_sort_coo_no_perm(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
+  // expected-error@+1 {{Expected a permutation map, got (d0, d1) -> (d0, d0)}}
+  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {perm_map = #NON_PERM_MAP, ny = 1 : index}: memref<?xindex>
+  return %arg1 : memref<?xindex>
+}
+
+// -----
+
 #CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>

 func.func @sparse_alloc_escapes(%arg0: index) -> tensor<10x?xf64, #CSR> {
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index d1262cb7aea02d..d252fa559a1543 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -612,79 +612,29 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (

 // -----

-// CHECK-LABEL: func @sparse_sort_1d0v(
-// CHECK-SAME: %[[A:.*]]: index,
-// CHECK-SAME: %[[B:.*]]: memref<?xindex>)
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] : memref<?xindex>
-// CHECK: return %[[B]]
-func.func @sparse_sort_1d0v(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 : memref<?xindex>
-  return %arg1 : memref<?xindex>
-}
-
-// -----
-
-// CHECK-LABEL: func @sparse_sort_1d2v(
-// CHECK-SAME: %[[A:.*]]: index,
-// CHECK-SAME: %[[B:.*]]: memref<20xindex>,
-// CHECK-SAME: %[[C:.*]]: memref<10xindex>,
-// CHECK-SAME: %[[D:.*]]: memref<?xf32>)
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
-// CHECK: return %[[B]], %[[C]], %[[D]]
-func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<20xindex>, %arg2: memref<10xindex>, %arg3: memref<?xf32>) -> (memref<20xindex>, memref<10xindex>, memref<?xf32>) {
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
-  return %arg1, %arg2, %arg3 : memref<20xindex>, memref<10xindex>, memref<?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @sparse_sort_2d1v(
-// CHECK-SAME: %[[A:.*]]: index,
-// CHECK-SAME: %[[B:.*]]: memref<10xi8>,
-// CHECK-SAME: %[[C:.*]]: memref<20xi8>,
-// CHECK-SAME: %[[D:.*]]: memref<10xf64>)
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
-// CHECK: return %[[B]], %[[C]], %[[D]]
-func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
-  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
-}
-
-// -----
-
-// CHECK-LABEL: func @sparse_sort_stable(
-// CHECK-SAME: %[[A:.*]]: index,
-// CHECK-SAME: %[[B:.*]]: memref<10xi8>,
-// CHECK-SAME: %[[C:.*]]: memref<20xi8>,
-// CHECK-SAME: %[[D:.*]]: memref<10xf64>)
-// CHECK: sparse_tensor.sort insertion_sort_stable %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
-// CHECK: return %[[B]], %[[C]], %[[D]]
-func.func @sparse_sort_stable(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
-  sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
-  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
-}
-
-// -----
+#ID_MAP = affine_map<(i,j) -> (i,j)>

 // CHECK-LABEL: func @sparse_sort_coo(
 // CHECK-SAME: %[[A:.*]]: index,
 // CHECK-SAME: %[[B:.*]]: memref<?xindex>)
-// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A]], %[[B]] {nx = 2 : index, ny = 1 : index} : memref<?xindex>
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A]], %[[B]] {ny = 1 : index, perm_map = #{{.*}}} : memref<?xindex>
 // CHECK: return %[[B]]
 func.func @sparse_sort_coo(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
-  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {nx = 2 : index, ny = 1 : index}: memref<?xindex>
+  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xindex>
   return %arg1 : memref<?xindex>
 }

 // -----

+#ID_MAP = affine_map<(i,j) -> (i,j)>
+
 // CHECK-LABEL: func @sparse_sort_coo_stable(
 // CHECK-SAME: %[[A:.*]]: index,
 // CHECK-SAME: %[[B:.*]]: memref<?xi64>,
 // CHECK-SAME: %[[C:.*]]: memref<?xf32>)
-// CHECK: sparse_tensor.sort_coo insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {nx = 2 : index, ny = 1 : index}
+// CHECK: sparse_tensor.sort_coo insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {ny = 1 : index, perm_map = #{{.*}}}
 // CHECK: return %[[B]], %[[C]]
 func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<?xi64>, %arg2: memref<?xf32>) -> (memref<?xi64>, memref<?xf32>) {
-  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2 {nx = 2 : index, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
+  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
   return %arg1, %arg2 : memref<?xi64>, memref<?xf32>
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
index b31ac3ef3a254a..5c308dc3c56234 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -116,7 +116,7 @@
 // CHECK: } {"Emitted from" = "linalg.generic"}
 // CHECK: scf.yield %[[VAL_64:.*]] : index
 // CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]] : memref<?xindex>
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]]
 // CHECK: %[[VAL_66:.*]]:4 = scf.for %[[VAL_67:.*]] = %[[VAL_10]] to %[[VAL_65]] step %[[VAL_11]] iter_args(%[[VAL_68:.*]] = %[[VAL_36]], %[[VAL_69:.*]] = %[[VAL_37]], %[[VAL_70:.*]] = %[[VAL_38]], %[[VAL_71:.*]] = %[[VAL_39]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 // CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_67]]] : memref<4xindex>
 // CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_72]]] : memref<4xf64>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index 0594b311184f4d..394b9a8448b543 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -28,7 +28,7 @@
 // Do the same run, but now with VLA vectorization.
 // RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}

-#ID_MAP = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#ID_MAP = affine_map<(d0, d1, d2) -> (d1, d2, d0)>

 module {
   // Stores 5 values to the memref buffer.
@@ -96,11 +96,11 @@ module {
     %y1 = memref.cast %y1s : memref<7xi32> to memref<?xi32>

     // Sort "parallel arrays".
-    // CHECK: ( 1, 1, 3, 3, 10 )
-    // CHECK: ( 2, 10, 1, 1, 5 )
-    // CHECK: ( 4, 2, 9, 9, 7 )
-    // CHECK: ( 10, 6, 7, 8, 9 )
-    // CHECK: ( 7, 5, 7, 4, 9 )
+    // CHECK: ( 1, 1, 2, 5, 10 )
+    // CHECK: ( 9, 9, 4, 7, 2 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 7, 8, 10, 9, 6 )
+    // CHECK: ( 7, 4, 7, 9, 5 )
     call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -111,24 +111,25 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {nx = #ID_MAP, ny = 1 : index}
+    sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
       : memref<?xi32> jointly memref<?xi32>
-    %x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
-    vector.print %x0v : vector<5xi32>
+    // Dump the memory in the same order as the perm_map so that the printed
+    // output appears in sorted order.
     %x1v = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x1v : vector<5xi32>
     %x2v = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x2v : vector<5xi32>
+    %x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x0v : vector<5xi32>
     %y0v = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %y0v : vector<5xi32>
     %y1v = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
     vector.print %y1v : vector<5xi32>

     // Stable sort.
-    // CHECK: ( 1, 1, 3, 3, 10 )
-    // CHECK: ( 2, 10, 1, 1, 5 )
-    // CHECK: ( 4, 2, 9, 9, 7 )
-    // CHECK: ( 10, 6, 8, 7, 9 )
-    // CHECK: ( 7, 5, 4, 7, 9 )
+    // CHECK: ( 1, 1, 2, 5, 10 )
+    // CHECK: ( 9, 9, 4, 7, 2 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 8, 7, 10, 9, 6 )
+    // CHECK: ( 4, 7, 7, 9, 5 )
     call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -139,24 +140,24 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {nx = #ID_MAP, ny = 1 : index}
+    sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
       : memref<?xi32> jointly memref<?xi32>
-    %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
-    vector.print %x0v2 : vector<5xi32>
     %x1v2 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x1v2 : vector<5xi32>
     %x2v2 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x2v2 : vector<5xi32>
+    %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x0v2 : vector<5xi32>
     %y0v2 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %y0v2 : vector<5xi32>
     %y1v2 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
     vector.print %y1v2 : vector<5xi32>

     // Heap sort.
-    // CHECK: ( 1, 1, 3, 3, 10 )
-    // CHECK: ( 2, 10, 1, 1, 5 )
-    // CHECK: ( 4, 2, 9, 9, 7 )
-    // CHECK: ( 10, 6, 8, 7, 9 )
-    // CHECK: ( 7, 5, 4, 7, 9 )
+    // CHECK: ( 1, 1, 2, 5, 10 )
+    // CHECK: ( 9, 9, 4, 7, 2 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 7, 8, 10, 9, 6 )
+    // CHECK: ( 7, 4, 7, 9, 5 )
     call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -167,14 +168,14 @@ module {
      : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {nx = #ID_MAP, ny = 1 : index}
+    sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
      : memref<?xi32> jointly memref<?xi32>
-    %x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
-    vector.print %x0v3 : vector<5xi32>
     %x1v3 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x1v3 : vector<5xi32>
     %x2v3 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x2v3 : vector<5xi32>
+    %x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x0v3 : vector<5xi32>
     %y0v3 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %y0v3 : vector<5xi32>
     %y1v3 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
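A note on the reordered dumps in the integration test above: with `perm_map = affine_map<(d0, d1, d2) -> (d1, d2, d0)>`, records compare lexicographically on `(x1, x2, x0)`, which is why the test now prints `x1`, `x2`, `x0` in that order to obtain monotonically sorted output. A self-contained sketch of the same pattern (hypothetical names; buffers passed via `jointly` are reordered with the records but never compared):

```mlir
// Sort n records of three interleaved keys each, comparing on
// (key1, key2, key0) as dictated by #PERM; %y is a payload buffer
// that is jointly reordered.
#PERM = affine_map<(d0, d1, d2) -> (d1, d2, d0)>

func.func @sort_coo_permuted(%n: index, %xy: memref<?xindex>, %y: memref<?xf32>) {
  sparse_tensor.sort_coo insertion_sort_stable %n, %xy jointly %y
    {perm_map = #PERM, ny = 0 : index}
    : memref<?xindex> jointly memref<?xf32>
  return
}
```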