[mlir][sparse] unifies sparse_tensor.sort_coo/sort into one operation. #66722

Merged: 4 commits, Sep 20, 2023
75 changes: 13 additions & 62 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -762,81 +762,32 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
 // Sparse Tensor Sorting Operations.
 //===----------------------------------------------------------------------===//

-def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
-Arguments<(ins Index:$n,
-Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
-Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
-SparseTensorSortKindAttr:$algorithm)> {
-string summary = "Sorts the arrays in xs and ys lexicographically on the "
-"integral values found in the xs list";
-string description = [{
-Lexicographically sort the first `n` values in `xs` along with the values in
-`ys`. Conceptually, the values being sorted are tuples produced by
-`zip(zip(xs), zip(ys))`. In particular, values in `ys` needed to be sorted
-along with values in `xs`, but values in `ys` don't affect the
-lexicographical order. The order in which arrays appear in `xs` affects the
-sorting result. The operator updates `xs` and `ys` in place with the result
-of the sorting.
-
-For example, assume x1=[4, 3], x2=[1, 2], y1=[10, 5], then the output of
-"sort 2, x1, x2 jointly y1" are x1=[3, 4], x2=[2, 1], y1=[5, 10] while the
-output of "sort 2, x2, x1, jointly y1" are x2=[1, 2], x1=[4, 3], y1=[10, 5].
-
-Buffers in `xs` needs to have the same integral element type while buffers
-in `ys` can have different numeric element types. All buffers in `xs` and
-`ys` should have a dimension not less than `n`. The behavior of the operator
-is undefined if this condition is not met. The operator requires at least
-one buffer in `xs` while `ys` can be empty.
-
-The enum attribute `algorithm` indicates the sorting algorithm used to
-implement the operator: hybrid_quick_sort, insertion_sort_stable,
-quick_sort, or heap_sort.
-
-Note that this operation is "impure" in the sense that its behavior is
-solely defined by side-effects and not SSA values.
-
-Example:
-
-```mlir
-sparse_tensor.sort insertion_sort_stable %n, %x1, %x2 jointly y1, %y2
-: memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
-```
-
-```mlir
-sparse_tensor.sort hybrid_quick_sort %n, %x1, %x2 jointly y1, %y2
-{ alg=1 : index}
-: memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
-```
-}];
-let assemblyFormat = "$algorithm $n `,` $xs (`jointly` $ys^)? attr-dict"
-"`:` type($xs) (`jointly` type($ys)^)?";
-let hasVerifier = 1;
-}
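For readers who want to check the semantics of the removed `sparse_tensor.sort` description, here is a minimal reference model in plain C++. It is only a sketch and not the MLIR lowering: `sortJointly` is a hypothetical helper name, the `ys` buffers are assumed to share one element type (`double`) for brevity even though the op allows them to differ, and `std::stable_sort` stands in for whichever `algorithm` the op would select.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical reference model (not an MLIR API): sorts the first n entries
// of every buffer in xs and ys in place. Only xs decides the order; the
// buffers are compared lexicographically in the order they appear in xs.
void sortJointly(std::size_t n, std::vector<std::vector<int64_t>> &xs,
                 std::vector<std::vector<double>> &ys) {
  assert(!xs.empty() && "need at least one xs buffer");
  for (const auto &x : xs)
    assert(x.size() >= n);
  for (const auto &y : ys)
    assert(y.size() >= n);

  // Compute the sorted order of tuple indices 0..n-1.
  std::vector<std::size_t> order(n);
  std::iota(order.begin(), order.end(), 0);
  std::stable_sort(order.begin(), order.end(),
                   [&](std::size_t a, std::size_t b) {
                     for (const auto &x : xs)
                       if (x[a] != x[b])
                         return x[a] < x[b];
                     return false; // equal keys: keep original order
                   });

  // Apply the same reordering to every buffer.
  auto permute = [&](auto &buf) {
    auto copy = buf;
    for (std::size_t i = 0; i < n; ++i)
      buf[i] = copy[order[i]];
  };
  for (auto &x : xs)
    permute(x);
  for (auto &y : ys)
    permute(y);
}
```

With `xs = {{4, 3}, {1, 2}}` and `ys = {{10, 5}}`, calling `sortJointly(2, xs, ys)` yields `x1 = [3, 4]`, `x2 = [2, 1]`, `y1 = [5, 10]`, as in the description. Swapping the two `xs` buffers leaves all data unchanged, because `x2 = [1, 2]` is already sorted.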

 def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
Review comment (Contributor): very nice, one op to rule them all!

 Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
 Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
-OptionalAttr<IndexAttr>:$nx, OptionalAttr<IndexAttr>:$ny,
+AffineMapAttr:$perm_map, OptionalAttr<IndexAttr>:$ny,
 SparseTensorSortKindAttr:$algorithm)> {
 let summary = "Sorts the arrays in xs and ys lexicographically on the "
 "integral values found in the xs list";
 let description = [{
-Sparse_tensor.sort_coo is similar to sparse_tensor.sort, except that all the
-`xs` values and some `ys` values are put in the linear buffer `xy`. The
-optional index attribute `nx` provides the number of `xs` values in `xy`.
-When `nx` is not explicitly specified, its value is 1. The optional index
-attribute `ny` provides the number of `ys` values in `xy`. When `ny` is not
-explicitly specified, its value is 0. This instruction supports a more
-efficient way to store the COO definition in sparse tensor type.
-
-The buffer xy should have a dimension not less than n * (nx + ny) while the
+Sparse_tensor.sort_coo sort the `xs` values along with some `ys` values
+that are put in a single linear buffer `xy`.
+The affine map attribute `perm_map` specifies the permutation to be applied on
+the `xs` before comparison, the rank of the permutation map
+also specifies the number of `xs` values in `xy`.
+The optional index attribute `ny` provides the number of `ys` values in `xy`.
+When `ny` is not explicitly specified, its value is 0.
+This instruction supports a more efficient way to store the COO definition
+in sparse tensor type.
+
+The buffer xy should have a dimension not less than n * (rank(perm_map) + ny) while the
 buffers in `ys` should have a dimension not less than `n`. The behavior of
 the operator is undefined if this condition is not met.

 Example:

 ```mlir
-sparse_tensor.sort_coo insertion_sort_stable %n, %x { nx = 2 : index}
+sparse_tensor.sort_coo insertion_sort_stable %n, %x { perm_map = affine_map<(i,j) -> (j,i)> }
 : memref<?xindex>
 ```
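To make the packed `xy` layout and the role of `perm_map` concrete, the following plain C++ sketch models `sort_coo` on a single linear buffer. It rests on stated assumptions rather than on the actual lowering: `sortCoo` is a hypothetical name, `perm` is a plain index vector standing in for the results of `perm_map` (so `(i,j) -> (j,i)` becomes `{1, 0}`, read as "compare the second x first"), each tuple is assumed to store its `xs` first and its `ny` packed `ys` after them, and the separate `ys` buffers are left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical reference model of the packed layout (not the MLIR lowering).
// Tuple i occupies xy[i * stride .. i * stride + stride - 1], where
// stride = perm.size() + ny: first the xs, then the ny packed ys (assumed).
// perm[k] names the x position that is compared k-th.
void sortCoo(std::size_t n, std::vector<int64_t> &xy,
             const std::vector<std::size_t> &perm, std::size_t ny) {
  const std::size_t nx = perm.size();
  const std::size_t stride = nx + ny;
  assert(nx >= 1);
  assert(xy.size() >= n * stride); // dimension(xy) >= n * (rank + ny)

  // Sort tuple indices by the permuted xs; the packed ys never take part.
  std::vector<std::size_t> order(n);
  std::iota(order.begin(), order.end(), 0);
  std::stable_sort(order.begin(), order.end(),
                   [&](std::size_t a, std::size_t b) {
                     for (std::size_t k : perm) {
                       const int64_t xa = xy[a * stride + k];
                       const int64_t xb = xy[b * stride + k];
                       if (xa != xb)
                         return xa < xb;
                     }
                     return false;
                   });

  // Move whole tuples (xs and packed ys together) into sorted position.
  const std::vector<int64_t> copy(xy.begin(), xy.begin() + n * stride);
  for (std::size_t i = 0; i < n; ++i)
    std::copy_n(copy.begin() + order[i] * stride, stride,
                xy.begin() + i * stride);
}
```

Packing the tuples `(4, 1 | 10)` and `(3, 2 | 5)` as `xy = {4, 1, 10, 3, 2, 5}` with `ny = 1`, an identity `perm = {0, 1}` produces `{3, 2, 5, 4, 1, 10}`, while `perm = {1, 0}` leaves the buffer unchanged because the second coordinate is already in order.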

45 changes: 10 additions & 35 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1353,48 +1353,22 @@ LogicalResult SelectOp::verify() {
 return success();
 }

-LogicalResult SortOp::verify() {
-if (getXs().empty())
-return emitError("need at least one xs buffer.");
-
-std::optional<int64_t> n = getConstantIntValue(getN());
-
-Type xtp = getMemRefType(getXs().front()).getElementType();
-auto checkTypes = [&](ValueRange operands,
-bool checkEleType = true) -> LogicalResult {
-for (Value opnd : operands) {
-auto mtp = getMemRefType(opnd);
-const DynSize sh = mtp.getShape()[0];
-// We can't check the size of dynamic dimension at compile-time, but all
-// xs and ys should have a dimension not less than n at runtime.
-if (n && !ShapedType::isDynamic(sh) && sh < n.value())
-return emitError(llvm::formatv("xs and ys need to have a dimension >= n"
-": {0} < {1}",
-sh, n.value()));
-
-if (checkEleType && xtp != mtp.getElementType())
-return emitError("mismatch xs element types");
-}
-return success();
-};
-RETURN_FAILURE_IF_FAILED(checkTypes(getXs()))
-return n ? checkTypes(getYs(), false) : success();
-}
-
 LogicalResult SortCooOp::verify() {
+AffineMap xPerm = getPermMap();
+uint64_t nx = xPerm.getNumDims();
+if (nx < 1)
+emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
+
+if (!xPerm.isPermutation())
+emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
+
 std::optional<int64_t> cn = getConstantIntValue(getN());
 // We can't check the size of the buffers when n or buffer dimensions aren't
 // compile-time constants.
 if (!cn)
 return success();

 uint64_t n = cn.value();
-uint64_t nx = 1;
-if (auto nxAttr = getNxAttr()) {
-nx = nxAttr.getInt();
-if (nx < 1)
-emitError(llvm::formatv("Expected nx > 1, got {0}", nx));
-}
 uint64_t ny = 0;
 if (auto nyAttr = getNyAttr()) {
 ny = nyAttr.getInt();
@@ -1409,7 +1383,8 @@ LogicalResult SortCooOp::verify() {
 emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
 };

-checkDim(getXy(), n * (nx + ny), "Expected dimension(xy) >= n * (nx + ny)");
+checkDim(getXy(), n * (nx + ny),
+"Expected dimension(xy) >= n * (rank(perm_map) + ny)");

 for (Value opnd : getYs()) {
 checkDim(opnd, n, "Expected dimension(y) >= n");
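The static size rules enforced by `SortCooOp::verify()` can be restated without any MLIR types. The sketch below is illustrative only: `Dim` and `checkSortCooSizes` are hypothetical names, a dynamic dimension is modeled as `std::nullopt`, and the function returns the first error message instead of emitting diagnostics. The rank check is written as `rank >= 1`, which is the condition the code tests (`nx < 1`).

```cpp
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a memref dimension: nullopt means "dynamic".
using Dim = std::optional<uint64_t>;

// Returns an error message when a statically known size is too small, and
// std::nullopt when everything checkable is fine. Sizes that are unknown at
// compile time (dynamic n or dynamic dimensions) cannot be checked here.
std::optional<std::string>
checkSortCooSizes(std::optional<uint64_t> n, uint64_t rank, uint64_t ny,
                  Dim xyDim, const std::vector<Dim> &yDims) {
  if (rank < 1)
    return "Expected rank(perm_map) >= 1";
  if (!n)
    return std::nullopt; // n is not a compile-time constant

  // xy holds n tuples of rank xs plus ny packed ys each.
  const uint64_t minXy = *n * (rank + ny);
  if (xyDim && *xyDim < minXy)
    return "Expected dimension(xy) >= n * (rank(perm_map) + ny)";

  // Every separate ys buffer needs at least n elements.
  for (const Dim &d : yDims)
    if (d && *d < *n)
      return "Expected dimension(y) >= n";
  return std::nullopt;
}
```

For `n = 10`, `rank = 2` and `ny = 1`, an `xy` buffer of 30 elements passes and one of 29 elements is rejected. A dynamic `n` or a dynamic dimension is accepted, since its size is only known at runtime.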