diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt
index 25a2e4869cc78..54ad9491cce51 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -12,3 +12,9 @@ set(LLVM_TARGET_DEFINITIONS SparseTensorTypes.td)
 mlir_tablegen(SparseTensorTypes.h.inc -gen-typedef-decls)
 mlir_tablegen(SparseTensorTypes.cpp.inc -gen-typedef-defs)
 add_public_tablegen_target(MLIRSparseTensorTypesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS SparseTensorInterfaces.td)
+mlir_tablegen(SparseTensorInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(SparseTensorInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRSparseTensorInterfacesIncGen)
+add_dependencies(mlir-headers MLIRSparseTensorInterfacesIncGen)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 3eb9ce010cb00..cbca0a7f8cc0e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
new file mode 100644
index 0000000000000..ebbc522123a59
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
@@ -0,0 +1,31 @@
+//===- SparseTensorInterfaces.h - Sparse tensor operation interfaces -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
+#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class PatternRewriter;
+
+namespace sparse_tensor {
+class StageWithSortSparseOp;
+
+namespace detail {
+LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op,
+                                PatternRewriter &rewriter);
+} // namespace detail
+} // namespace sparse_tensor
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h.inc"
+
+#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
new file mode 100644
index 0000000000000..1379363ff75f4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
@@ -0,0 +1,45 @@
+//===- SparseTensorInterfaces.td --------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPARSETENSOR_IR_SPARSETENSORINTERFACES
+#define SPARSETENSOR_IR_SPARSETENSORINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
+  let description = [{
+    A stage-with-sort sparse tensor operation is an operation that produces
+    an unordered intermediate result. An extra sort is required to obtain the
+    final ordered result.
+
+    E.g., convert csr -> csc needs to be implemented as
+    convert csr -> unordered coo -> sort by column -> csc; and
+    concatenate csr, csc -> csr can be staged into
+    concatenate csr, csc -> unordered coo -> sort by row -> csr.
+  }];
+  let cppNamespace = "::mlir::sparse_tensor";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/"Return true if the operation needs an extra sort to produce the final result.",
+      /*retTy=*/"bool",
+      /*methodName=*/"needsExtraSort",
+      /*args=*/(ins),
+      /*methodBody=*/"">,
+    InterfaceMethod<
+      /*desc=*/"Stage the operation into simpler steps followed by a sort; returns failure if no staging is needed.",
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"stageWithSort",
+      /*args=*/(ins "::mlir::PatternRewriter &":$rewriter),
+      /*methodBody=*/[{
+        return detail::stageWithSortImpl($_op, rewriter);
+      }]>,
+  ];
+}
+
+#endif // SPARSETENSOR_IR_SPARSETENSORINTERFACES
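The TableGen definition above generates a C++ op interface class named `StageWithSortSparseOp`, so generic code can query staging behavior without knowing the concrete operation. A minimal sketch of such a query (the `inspectOp` helper is hypothetical; `dyn_cast` and the two interface methods are as declared above):

```cpp
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Reports whether an operation implementing the interface will need the
// extra sort stage when lowered.
static void inspectOp(Operation *op) {
  if (auto stageOp = dyn_cast<StageWithSortSparseOp>(op))
    if (stageOp.needsExtraSort())
      op->emitRemark("produces an unordered intermediate; requires a sort");
}
```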
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 9016634fa3be8..3d1807094797e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
 include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td"
 include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
+include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
@@ -153,7 +154,7 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
 }
 
 def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
-    [Pure]>,
+    [Pure, StageWithSortSparseOpInterface]>,
     Arguments<(ins AnyTensor:$source)>,
     Results<(outs AnyTensor:$dest)> {
   string summary = "Converts between different tensor types";
@@ -197,9 +198,9 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
   }];
 
   let extraClassDeclaration = [{
-    // Whether the convert can be done by a single step (either a sort or a
-    // foreach), or it would require a tmp buffer (sort, then foreach).
-    bool directConvertable();
+    // Whether the convert can be done in a single step, or whether it
+    // requires an extra sort. Inherited from StageWithSortSparseOpInterface.
+    bool needsExtraSort();
   }];
 
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
@@ -334,7 +335,8 @@ def SparseTensor_NumberOfEntriesOp : SparseTensor_Op<"number_of_entries", [Pure]
   let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
 }
 
-def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>,
+def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate",
+    [Pure, StageWithSortSparseOpInterface]>,
     Arguments<(ins Variadic<AnyRankedTensor>:$inputs, DimensionAttr:$dimension)>,
     Results<(outs AnyRankedTensor:$result)> {
 
@@ -357,6 +359,12 @@
     ```
   }];
 
+  let extraClassDeclaration = [{
+    // Whether the concatenate can be done in a single step, or whether it
+    // requires an extra sort. Inherited from StageWithSortSparseOpInterface.
+    bool needsExtraSort();
+  }];
+
   let assemblyFormat = "$inputs attr-dict `:` type($inputs) `to` type($result)";
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
index b22194d45062a..dd6f1037f71b5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -29,6 +29,7 @@ endif()
 
 add_mlir_dialect_library(MLIRSparseTensorDialect
   SparseTensorDialect.cpp
+  SparseTensorInterfaces.cpp
   Detail/Var.cpp
   Detail/DimLvlMap.cpp
   Detail/LvlTypeParser.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 61522fb0dcd24..cd1e585438dda 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1065,18 +1065,18 @@ OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-bool ConvertOp::directConvertable() {
+bool ConvertOp::needsExtraSort() {
   SparseTensorType srcStt = getSparseTensorType(getSource());
   SparseTensorType dstStt = getSparseTensorType(getDest());
 
-  // We can always directly convert to unordered sparse tensor or dense tensor
-  // since dense tensor support random access.
+  // We do not need an extra sort when returning unordered sparse tensors or
+  // dense tensors, since dense tensors support random access.
   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
-    return true;
+    return false;
 
   if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
       srcStt.hasSameDimToLvl(dstStt)) {
-    return true;
+    return false;
   }
 
   // Source and dest tensors are ordered in different ways. We only do direct
@@ -1086,9 +1086,9 @@ bool ConvertOp::directConvertable() {
   // performance.
   if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
     if (isa<SparseElementsAttr>(constOp.getValue()))
-      return true;
+      return false;
 
-  return false;
+  return true;
 }
 
 LogicalResult ToPositionsOp::verify() {
@@ -1248,6 +1248,23 @@ LogicalResult UnaryOp::verify() {
   return success();
 }
 
+bool ConcatenateOp::needsExtraSort() {
+  SparseTensorType dstStt = getSparseTensorType(*this);
+  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
+    return false;
+
+  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
+    return getSparseTensorType(op).hasSameDimToLvl(dstStt);
+  });
+  // TODO: When conDim != 0, as long as conDim corresponds to the first level
+  // in all input/output buffers, and all input/output buffers have the same
+  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
+  // CSC matrices along column).
+  bool directLowerable =
+      allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
+  return !directLowerable;
+}
+
 LogicalResult ConcatenateOp::verify() {
   const auto dstTp = getSparseTensorType(*this);
   const Dimension concatDim = getDimension();
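With both implementations in place, `needsExtraSort` gives callers a single predicate for choosing between the direct lowering path and the staged one. A hedged sketch of that dispatch (both helper functions are hypothetical stand-ins; the patch itself expresses the two paths as the `DirectConvertRewriter` and `StageUnorderedSparseOps` patterns shown later):

```cpp
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Hypothetical single-step lowering (stand-in for DirectConvertRewriter).
static LogicalResult lowerDirectConvert(ConvertOp op,
                                        PatternRewriter &rewriter) {
  return failure(); // placeholder; the real path is a foreach-based rewrite
}

static LogicalResult lowerConvert(ConvertOp op, PatternRewriter &rewriter) {
  if (!op.needsExtraSort())
    return lowerDirectConvert(op, rewriter); // ordered output falls out directly
  // Otherwise stage: clone with an unordered COO result, sort, then convert.
  return op.stageWithSort(rewriter);
}
```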
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
new file mode 100644
index 0000000000000..d8769eacc44f3
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -0,0 +1,55 @@
+//===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
+
+LogicalResult
+sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
+                                         PatternRewriter &rewriter) {
+  if (!op.needsExtraSort())
+    return failure();
+
+  Location loc = op.getLoc();
+  Type finalTp = op->getOpResult(0).getType();
+  SparseTensorType dstStt(finalTp.cast<RankedTensorType>());
+
+  Type srcCOOTp = getCOOFromTypeWithOrdering(
+      dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
+
+  // Clones the original operation but changes the output to an unordered COO.
+  Operation *cloned = rewriter.clone(*op.getOperation());
+  rewriter.updateRootInPlace(cloned, [cloned, srcCOOTp]() {
+    cloned->getOpResult(0).setType(srcCOOTp);
+  });
+  Value srcCOO = cloned->getOpResult(0);
+
+  // -> sort
+  Type dstCOOTp = getCOOFromTypeWithOrdering(
+      dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
+  Value dstCOO = rewriter.create<ConvertOp>(
+      loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
+
+  // -> dest.
+  if (dstCOO.getType() == finalTp) {
+    rewriter.replaceOp(op, dstCOO);
+  } else {
+    // Need an extra conversion if the target type is not COO.
+    rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
+  }
+  // TODO: deallocate extra COOs; we should probably delegate this to the
+  // buffer deallocation pass.
+  return success();
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index a1ab2495f5f7b..1bfee3aa1d7ee 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -829,10 +829,56 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
   }
 };
 
+struct TensorLike {
+  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
+             ValueRange sizes)
+      : isSparse(rtt.getEncoding() != nullptr) {
+    SmallVector<Value> dynSzs;
+    getDynamicSizes(rtt, sizes, dynSzs);
+
+    if (isSparse)
+      val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
+    else
+      val = allocDenseTensor(builder, loc, rtt, sizes);
+  };
+
+  void insertOrStore(OpBuilder &builder, Location loc, Value v,
+                     ValueRange crds) {
+    if (isSparse)
+      val = builder.create<InsertOp>(loc, v, val, crds);
+    else
+      builder.create<memref::StoreOp>(loc, v, val, crds);
+  }
+
+  Value getSSA() const {
+    // We don't need to maintain the SSA chain for a memref value.
+    return isSparse ? val : nullptr;
+  }
+
+  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
+    if (isSparse)
+      return builder.create<LoadOp>(loc, val, true);
+    return builder.create<ToTensorOp>(loc, rtp, val);
+  }
+
+  void updateSSA(Value v) {
+    // Dense memref is a non-SSA value.
+    assert(isSparse);
+    val = v;
+  }
+
+private:
+  bool isSparse;
+  Value val; // either a memref (for a dense tensor) or a sparse tensor.
+};
+
 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ConcatenateOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.needsExtraSort())
+      return op.emitError("ConcatenateOp not staged");
+
     const Location loc = op.getLoc();
     const auto dstTp = getSparseTensorType(op);
     const Dimension dimRank = dstTp.getDimRank();
@@ -852,94 +898,54 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
     // foreach in %s1 : insert d0, d1, %tmp
     // foreach in %s2 : insert d0, d1 + size(s1), %tmp
     // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
-    // %t = convert_to_dest_tensor(%tmp)
-    //
-    // NOTE: this cannot be `const` because it will be changed when
-    // `needTmpCOO`, but that's buried in the conditional below and
-    // thus not easily extracted.
-    auto encDst = dstTp.getEncoding();
-    Value dst; // Destination tensor for inserting source tensor values.
-    bool needTmpCOO = true;
-    const bool allDense = dstTp.hasEncoding() && dstTp.isAllDense();
-    Value annotatedDenseDst;
-    if (dstTp.hasEncoding()) {
-      bool allOrdered = false;
-      // When concatenating on dimension 0, and all inputs are sorted
-      // and have an identity dimToLvl, the concatenate will generate
-      // coords in lexOrder thus no need for the tmp COO buffer.
-      // TODO: When conDim != 0, as long as conDim is the first dimension
-      // in all input/output buffers, and all input/output buffers have the same
-      // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
-      // CSC matrices along column).
-      if (!allDense && conDim == 0 && dstTp.isIdentity()) {
-        for (auto i : op.getInputs()) {
-          const auto stt = getSparseTensorType(i);
-          allOrdered = stt.isAllOrdered() && stt.isIdentity();
-          if (!allOrdered)
-            break;
-        }
-      }
-
-      needTmpCOO = !allDense && !allOrdered;
-      const RankedTensorType tp = getBufferType(dstTp, needTmpCOO);
-      encDst = needTmpCOO ? getSparseTensorEncoding(tp) : encDst;
-      SmallVector<Value> dynSizes;
-      getDynamicSizes(dstTp, sizes, dynSizes);
-      dst = rewriter.create<AllocTensorOp>(loc, tp, dynSizes).getResult();
-      if (allDense) {
-        // Create a view of the values buffer to match the unannotated dense
-        // tensor.
-        Value valuesBuffer = genToValues(rewriter, loc, dst);
-        Value dimCoords =
-            genAlloca(rewriter, loc, dimRank, rewriter.getIndexType(),
-                      /*staticShape=*/true);
-        annotatedDenseDst = dst;
-        dst = reshapeValuesToLevels(rewriter, loc, encDst, sizes, valuesBuffer,
-                                    dimCoords);
-      }
-    } else {
-      // TODO: Dense buffers should be allocated/deallocated via the callback
-      // in BufferizationOptions.
-      dst = allocDenseTensor(rewriter, loc, dstTp, sizes);
-    }
+    TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
 
     Value offset = constantIndex(rewriter, loc, 0);
-    SmallVector<Value> initArgs;
-    if (encDst && !allDense)
-      initArgs.push_back(dst);
+    Value iterArg = dstBuf.getSSA();
+
     ForeachOp foreachOp;
     for (Value input : op.getInputs()) {
-      // Build a for op for each input tensor to append new values into the
+      // Builds a foreach op for each input tensor to append new values into the
       // output tensor.
       foreachOp = rewriter.create<ForeachOp>(
-          loc, input, initArgs,
+          loc, input, iterArg ? ValueRange{iterArg} : ValueRange{},
           [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
               ValueRange reduc) {
            SmallVector<Value> dstLcvs(dstTp.getLvlRank());
            for (Dimension d = 0; d < dimRank; d++) {
              Value crd = dcvs[d];
+              // Transforms the coordinate for the concatenating dimension.
              if (d == conDim)
-                // Transform coordinates for the concatenating dim.
                crd = builder.create<arith::AddIOp>(loc, crd, offset);
              // FIXME: `toStoredDim` is deprecated
-              dstLcvs[toStoredDim(encDst, d)] = crd;
+              dstLcvs[toStoredDim(dstTp.getEncoding(), d)] = crd;
            }
-            if (encDst && !allDense) {
-              Value cond = genIsNonzero(rewriter, loc, v);
-              scf::IfOp ifOp = builder.create<scf::IfOp>(
-                  loc, TypeRange(reduc.front().getType()), cond, /*else*/ true);
+
+            if (!reduc.empty())
+              dstBuf.updateSSA(reduc.front());
+
+            if (!dstTp.isAllDense()) {
+              Value cond = genIsNonzero(builder, loc, v);
+              auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
+                                                    /*else*/ true);
+              builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+
               builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-              Value t =
-                  builder.create<InsertOp>(loc, v, reduc.front(), dstLcvs);
-              rewriter.create<scf::YieldOp>(loc, t);
-              rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
-              rewriter.create<scf::YieldOp>(loc, reduc.front());
-              rewriter.setInsertionPointAfter(ifOp);
-              rewriter.create<sparse_tensor::YieldOp>(loc, ifOp.getResult(0));
+              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
+              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+
+              // Exits the ifOp and updates the sparse tensor SSA value.
+              builder.setInsertionPointAfter(ifOp);
+              assert(!reduc.empty());
+              dstBuf.updateSSA(ifOp.getResult(0));
            } else {
-              builder.create<memref::StoreOp>(loc, v, dst, dstLcvs);
-              builder.create<sparse_tensor::YieldOp>(loc);
+              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
            }
+            if (reduc.empty())
+              builder.create<sparse_tensor::YieldOp>(loc);
+            else
+              builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
          });
 
       // Accumulates the offset. Note that only static-shaped inputs are allowed
      // by concatenate op verifier, which saves us from computing the offset
@@ -948,88 +954,27 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
       assert(sh.has_value());
       offset = rewriter.create<arith::AddIOp>(
           loc, offset, constantIndex(rewriter, loc, *sh));
-      if (encDst && !allDense) {
-        dst = foreachOp.getResult(0);
-        initArgs[0] = dst;
-      }
-    }
 
-    // Temp variable to avoid needing to call `getRankedTensorType`
-    // in the three use-sites below.
-    const RankedTensorType dstRTT = dstTp;
-    if (!encDst) {
-      rewriter.replaceOpWithNewOp<ToTensorOp>(op, dstRTT, dst);
-    } else if (allDense) {
-      rewriter.replaceOp(
-          op, rewriter.create<ConvertOp>(loc, dstRTT, annotatedDenseDst)
-                  .getResult());
-    } else {
-      dst = rewriter.create<LoadOp>(loc, dst, true);
-      if (needTmpCOO) {
-        Value tmpCoo = dst;
-        Type dstCooTp = getCOOType(dstRTT, true);
-        // TODO: this should be a sort_coo operation.
-        dst = rewriter
-                  .create<ConvertOp>(loc, dstCooTp, tmpCoo,
-                                     SparseTensorSortKind::HybridQuickSort)
-                  .getResult();
-        dst = rewriter.create<ConvertOp>(loc, dstRTT, dst).getResult();
-        rewriter.create<DeallocTensorOp>(loc, tmpCoo);
+      if (!foreachOp.getResults().empty()) {
+        iterArg = foreachOp.getResult(0);
+        dstBuf.updateSSA(iterArg);
       }
-      rewriter.replaceOp(op, dst);
     }
-    return success();
-  }
-};
 
-struct TensorLike {
-  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
-             ValueRange sizes)
-      : isSparse(rtt.getEncoding() != nullptr) {
-    SmallVector<Value> dynSzs;
-    getDynamicSizes(rtt, sizes, dynSzs);
-
-    if (isSparse)
-      val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
-    else
-      val = allocDenseTensor(builder, loc, rtt, sizes);
-  };
-
-  void insertOrStore(OpBuilder &builder, Location loc, Value v,
-                     ValueRange crds) {
-    if (isSparse)
-      val = builder.create<InsertOp>(loc, v, val, crds);
-    else
-      builder.create<memref::StoreOp>(loc, v, val, crds);
-  }
-
-  Value getSSA() const {
-    // We don't need to maintain the SSA chain for a memref value.
-    return isSparse ? val : nullptr;
-  }
-
-  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
-    if (isSparse)
-      return builder.create<LoadOp>(loc, val, true);
-    return builder.create<ToTensorOp>(loc, rtp, val);
-  }
+    if (!foreachOp.getResults().empty())
+      dstBuf.updateSSA(iterArg);
 
-  void updateSSA(Value v) {
-    // Dense memref is a non-SSA value.
-    assert(isSparse);
-    val = v;
+    Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
+    rewriter.replaceOp(op, ret);
+    return success();
   }
-
-private:
-  bool isSparse;
-  Value val; // either a memref (for dense tensor) or a sparse tensor.
 };
 
 struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ConvertOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!op.directConvertable())
+    if (op.needsExtraSort())
      return op.emitError("ConvertOp not staged.");
 
     // TODO: Maybe we want a different operation for this too.
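The `TensorLike` helper introduced above hides the dense/sparse asymmetry in one place: a sparse insert produces a new SSA value that must be carried through the `ForeachOp` as an iteration argument, while a dense `memref.store` mutates a buffer and carries nothing. A condensed sketch of that threading discipline, assuming `TensorLike` is in scope and an identity dimToLvl so dimension coordinates can serve as level coordinates directly (`buildCopy` is a hypothetical helper, simplified from `ConcatenateRewriter` above):

```cpp
// Copies all elements of `src` into a fresh destination of type `rtt`,
// threading the TensorLike SSA value through the foreach only when sparse.
static Value buildCopy(PatternRewriter &rewriter, Location loc,
                       RankedTensorType rtt, Value src, ValueRange sizes) {
  TensorLike dstBuf(rewriter, loc, rtt, sizes);
  Value iterArg = dstBuf.getSSA(); // null when the destination is dense
  auto foreachOp = rewriter.create<ForeachOp>(
      loc, src, iterArg ? ValueRange{iterArg} : ValueRange{},
      [&](OpBuilder &builder, Location loc, ValueRange crds, Value v,
          ValueRange reduc) {
        if (!reduc.empty())
          dstBuf.updateSSA(reduc.front()); // resume from loop-carried value
        dstBuf.insertOrStore(builder, loc, v, crds);
        if (reduc.empty())
          builder.create<sparse_tensor::YieldOp>(loc);
        else
          builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
      });
  if (!foreachOp.getResults().empty())
    dstBuf.updateSSA(foreachOp.getResult(0));
  return dstBuf.finalize(rewriter, loc, rtt);
}
```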
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
index 4c163ea6e067b..5875cd4f9fd9d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -15,56 +15,19 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
-struct StageUnorderedConvert : public OpRewritePattern<ConvertOp> {
-  using OpRewritePattern::OpRewritePattern;
+template <typename StageWithSortOp>
+struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> {
+  using OpRewritePattern<StageWithSortOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(ConvertOp op,
+  LogicalResult matchAndRewrite(StageWithSortOp op,
                                 PatternRewriter &rewriter) const override {
-    // TODO: Implement it as an Interface, this can be reused from other
-    // operations too (e.g., concatenate, reshape, etc).
-    if (op.directConvertable())
-      return failure();
-
-    Location loc = op.getLoc();
-    SparseTensorType srcStt = getSparseTensorType(op.getSource());
-    SparseTensorType dstStt = getSparseTensorType(op.getDest());
-
-    // Just to make sure that convert to dense tensor is always direct.
-    assert(!dstStt.isAllDense());
-
-    // source -> coo
-    // The tmp COO must be unordered, otherwise it is a direct conversion.
-    assert(!(srcStt.hasSameDimToLvl(dstStt) && srcStt.isAllOrdered()));
-    (void)srcStt; // to silence warning when assertion is disabled
-
-    Type srcCOOTp = getCOOFromTypeWithOrdering(
-        dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
-    Value srcCOO = op.getSource();
-    if (srcCOO.getType() != srcCOOTp)
-      srcCOO = rewriter.create<ConvertOp>(loc, srcCOOTp, op.getSource());
-
-    // -> sort
-    Type dstCOOTp = getCOOFromTypeWithOrdering(
-        dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
-    Value dstCOO = rewriter.create<ConvertOp>(
-        loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
-
-    // -> dest.
-    if (dstCOO.getType() == op.getType()) {
-      rewriter.replaceOp(op, dstCOO);
-    } else {
-      // Need an extra conversion if the target type is not COO.
-      rewriter.replaceOpWithNewOp<ConvertOp>(op, op.getDest().getType(),
-                                             dstCOO);
-    }
-    // TODO: deallocate extra COOs, we should probably delegate it to buffer
-    // deallocation pass.
-
-    return success();
+    return llvm::cast<StageWithSortSparseOp>(op.getOperation())
+        .stageWithSort(rewriter);
   }
 };
 } // namespace
 
 void mlir::populateStageSparseOperationsPatterns(RewritePatternSet &patterns) {
-  patterns.add<StageUnorderedConvert>(patterns.getContext());
+  patterns.add<StageUnorderedSparseOps<ConvertOp>,
+               StageUnorderedSparseOps<ConcatenateOp>>(patterns.getContext());
 }
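For context, this is how the staging patterns are typically driven. A sketch assuming the populate function is declared in the transforms `Passes.h` header and using the standard greedy driver (the `stageSparseOps` wrapper is hypothetical):

```cpp
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// After this runs to fixpoint, every convert/concatenate either lowers
// directly or has been split into op -> unordered COO -> sort -> convert.
static LogicalResult stageSparseOps(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  populateStageSparseOperationsPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```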
- - return success(); + return llvm::cast(op.getOperation()) + .stageWithSort(rewriter); } }; } // namespace void mlir::populateStageSparseOperationsPatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + patterns.add, + StageUnorderedSparseOps>(patterns.getContext()); } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 63f9cdafce88b..30a8ee557e365 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2683,6 +2683,7 @@ td_library( srcs = [ "include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td", "include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td", + "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td", "include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td", "include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td", ], @@ -2694,6 +2695,15 @@ td_library( ], ) +td_library( + name = "SparseTensorInterfacesTdFiles", + srcs = [ + "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td", + ], + includes = ["include"], + deps = [":OpBaseTdFiles"], +) + gentbl_cc_library( name = "SparseTensorAttrDefsIncGen", tbl_outs = [ @@ -2801,6 +2811,23 @@ gentbl_cc_library( deps = [":PassBaseTdFiles"], ) +gentbl_cc_library( + name = "SparseTensorInterfacesIncGen", + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td", + deps = [":SparseTensorInterfacesTdFiles"], +) + # This library is shared by both SparseTensorDialect and # SparseTensorRuntime, so it must not depend on any of the MLIR/LLVM # internals or else mlir_c_runner_utils will inherit that dependency. @@ -2823,9 +2850,11 @@ cc_library( "lib/Dialect/SparseTensor/IR/Detail/Var.cpp", "lib/Dialect/SparseTensor/IR/Detail/Var.h", "lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp", + "lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp", ], hdrs = [ "include/mlir/Dialect/SparseTensor/IR/SparseTensor.h", + "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h", "include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h", "include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h", ], @@ -2837,6 +2866,7 @@ cc_library( ":InferTypeOpInterface", ":SparseTensorAttrDefsIncGen", ":SparseTensorEnums", + ":SparseTensorInterfacesIncGen", ":SparseTensorOpsIncGen", ":SparseTensorTypesIncGen", "//llvm:Support",