[mlir][vector] add tensor.concat, bitcast, expand_shape, collapse_shape vectorization support #97297
@@ -1718,6 +1718,209 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
  return success();
}

/// Vectorize a `tensor.expand_shape` or `tensor.collapse_shape` into these
/// 3 ops:
///   vector::TransferReadOp - reads a vector from the source tensor.
///   vector::ShapeCastOp - reshapes the data to the target shape.
///   vector::TransferWriteOp - writes the result vector back to the
///   destination tensor.
static LogicalResult lowerTensorReshape(RewriterBase &rewriter,
                                        Operation *inputOp,
                                        ArrayRef<int64_t> inputVectorSizes,
                                        SmallVectorImpl<Value> &newResults) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(inputOp);
  auto src = inputOp->getOperand(0);
  auto srcType = mlir::dyn_cast<ShapedType>(src.getType());
  auto result = inputOp->getResults()[0];

Review comment: In MLIR, we generally dislike "raw" indexed accessors. Can this function be turned into a function template and use named accessors instead?

  auto resultType = mlir::dyn_cast<ShapedType>(result.getType());
  ArrayRef<int64_t> resultShape = resultType.getShape();
  ArrayRef<int64_t> srcShape = srcType.getShape();

Review comment (lines +1736 to +1737): This most likely needs at least an assertion that shapes are static. The logic below will become wild if given a …

  Location loc = inputOp->getLoc();

  llvm::SmallVector<int64_t> srcVectorizedShape;

Review comment: No need to prefix `llvm::`.

  llvm::SmallDenseMap<int64_t, int64_t> shapeScales;

  auto getVectorizeShape = [&](ArrayRef<int64_t> &retShape,
                               ArrayRef<int64_t> &inputShape) {
    bool isResultShapeBigger = srcType.getRank() < resultType.getRank();

    int64_t cur = 1, resultIdx = 0;
    for (auto [srcIdx, ss] : llvm::enumerate(inputShape)) {
      cur *= ss;
      if (!isResultShapeBigger) {
        // collapse

Review comment: Please use full sentences in comments, including capitalization and trailing full stops.

        srcVectorizedShape.emplace_back(ss);
        if (cur == retShape[resultIdx]) {
          if (shapeScales.count(resultIdx)) {
            srcVectorizedShape.back() *= shapeScales[resultIdx];
          }
          cur = 1;
          resultIdx++;
        }
      } else {
        // expand
        if (cur == retShape[resultIdx]) {
          srcVectorizedShape.emplace_back(cur);
          if (shapeScales.count(srcIdx)) {
            srcVectorizedShape.back() *= shapeScales[srcIdx];
          }
          cur = 1;
          resultIdx++;
        }
      }
    }
  };
  if (!inputVectorSizes.empty()) {
    for (auto [idx, vs] : llvm::enumerate(inputVectorSizes)) {
      if (vs != resultShape[idx])
        shapeScales[idx] = vs / resultShape[idx];
    }

    bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
    if (!isResultShapeBigger) {
      getVectorizeShape(resultShape, srcShape);
    } else {
      getVectorizeShape(srcShape, resultShape);
    }
  } else {
    srcVectorizedShape.assign(srcShape.begin(), srcShape.end());
  }
  // read
  auto padValue = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(srcType.getElementType()));
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, src,
      inputVectorSizes.empty() ? srcType.getShape() : srcVectorizedShape,
      padValue, false);

  auto shapeCastType =
      VectorType::get(inputVectorSizes.empty() ? resultShape : inputVectorSizes,
                      resultType.getElementType());
  vector::ShapeCastOp shapeCastOp =
      rewriter.create<vector::ShapeCastOp>(loc, shapeCastType, readResult);

  // write
  SmallVector<OpFoldResult> destSizes;
  for (auto size : resultShape) {
    destSizes.emplace_back(rewriter.getIndexAttr(size));
  }

Review comment (lines +1803 to +1806): Reserve space before appending in a loop. Or better, use a proper combinator like …

  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, shapeCastOp->getResults()[0], destSizes,
      inputVectorSizes.empty() ? resultShape : inputVectorSizes, false);
  newResults.push_back(write->getResult(0));
  return success();
}
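To make the rewrite concrete, here is a rough sketch of the IR this lowering is expected to produce for a small, fully static `tensor.collapse_shape` (the shapes, SSA names, and attributes below are made up for illustration; when the requested vector sizes differ from the static sizes, masked reads and writes are generated instead):

```mlir
// Hypothetical input:
//   %0 = tensor.collapse_shape %src [[0, 1], [2]]
//       : tensor<2x4x8xf32> into tensor<8x8xf32>
// Roughly what the vectorized form looks like:
%c0 = arith.constant 0 : index
%pad = arith.constant 0.000000e+00 : f32
%read = vector.transfer_read %src[%c0, %c0, %c0], %pad
    {in_bounds = [true, true, true]} : tensor<2x4x8xf32>, vector<2x4x8xf32>
%cast = vector.shape_cast %read : vector<2x4x8xf32> to vector<8x8xf32>
%init = tensor.empty() : tensor<8x8xf32>
%written = vector.transfer_write %cast, %init[%c0, %c0]
    {in_bounds = [true, true]} : vector<8x8xf32>, tensor<8x8xf32>
```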

/// Vectorize a `tensor.bitcast` into these 3 ops:
///   vector::TransferReadOp - reads a vector from the source tensor.
///   vector::BitCastOp - bitcasts the data to the target element type.
///   vector::TransferWriteOp - writes the result vector back to the
///   destination tensor.
static LogicalResult lowerTensorBitcastOp(RewriterBase &rewriter,
                                          tensor::BitcastOp bitCastOp,
                                          ArrayRef<int64_t> inputVectorSizes,
                                          SmallVectorImpl<Value> &newResults) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(bitCastOp);

  auto sourceType = bitCastOp.getSource().getType();
  auto resultType = bitCastOp.getResult().getType();
  auto resultShape = resultType.getShape();
  if (inputVectorSizes.empty()) {
    inputVectorSizes = resultShape;
  }
  Location loc = bitCastOp->getLoc();

  // read
  auto padValue = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(sourceType.getElementType()));
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, bitCastOp.getSource(), inputVectorSizes, padValue, false);

  // bitcast
  auto resultVectorType =
      VectorType::get(inputVectorSizes, resultType.getElementType());
  vector::BitCastOp vectorbitCastOp =
      rewriter.create<vector::BitCastOp>(loc, resultVectorType, readResult);

  // write
  llvm::SmallVector<OpFoldResult> destSizes;
  for (auto size : resultShape)
    destSizes.emplace_back(rewriter.getIndexAttr(size));
  auto write =
      createWriteOrMaskedWrite(rewriter, loc, vectorbitCastOp->getResult(0),
                               destSizes, inputVectorSizes, false);
  newResults.push_back(write->getResults()[0]);
  return success();
}
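For reference, a minimal sketch of the analogous rewrite for `tensor.bitcast`, again with hypothetical shapes and names (the source and result element types are assumed to have the same bit width, which the op requires):

```mlir
// Hypothetical input:
//   %0 = tensor.bitcast %src : tensor<4x8xi32> to tensor<4x8xf32>
// Roughly what the vectorized form looks like:
%c0 = arith.constant 0 : index
%pad = arith.constant 0 : i32
%read = vector.transfer_read %src[%c0, %c0], %pad
    {in_bounds = [true, true]} : tensor<4x8xi32>, vector<4x8xi32>
%bc = vector.bitcast %read : vector<4x8xi32> to vector<4x8xf32>
%init = tensor.empty() : tensor<4x8xf32>
%written = vector.transfer_write %bc, %init[%c0, %c0]
    {in_bounds = [true, true]} : vector<4x8xf32>, tensor<4x8xf32>
```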

/// Vectorize a `tensor.concat` into these ops:
///   tensor::EmptyOp - creates the destination tensor.
///   vector::TransferReadOp - reads a vector from each input tensor.
///   vector::TransferWriteOp - writes each input vector into the destination
///   tensor at its offset along the concatenation dimension.
static LogicalResult lowerTensorConcatOp(RewriterBase &rewriter,
                                         tensor::ConcatOp concatOp,
                                         ArrayRef<int64_t> inputVectorSizes,
                                         SmallVectorImpl<Value> &newResults) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(concatOp);

  Location loc = concatOp.getLoc();
  FailureOr<Value> dest =
      tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
  if (failed(dest))
    return failure();

  auto empty = dest->getDefiningOp<tensor::EmptyOp>();
  if (!empty)
    return failure();

  // Compute the partial sums for the slice offsets.
  auto dim = concatOp.getDim();
  Value dimValue =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));

  int64_t rank = concatOp.getResultType().getRank();
  auto srcType =
      mlir::dyn_cast<RankedTensorType>(concatOp->getResultTypes()[0]);
  auto padValue = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(srcType.getElementType()));

  // Construct the chain of transfer_write ops into the destination.
  Value result = *dest;
  Value previous_offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  for (auto [idx, input] : llvm::enumerate(concatOp.getInputs())) {

    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, loc, input);
    SmallVector<int64_t> readMaskShape;
    auto inputType = mlir::dyn_cast<RankedTensorType>(input.getType());
    auto sourceShape = inputType.getShape();

    readMaskShape.append(sourceShape.begin(), sourceShape.end());
    Value readResult = vector::createReadOrMaskedRead(
        rewriter, loc, input, sourceShape, padValue, false);
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> indices(rank, zero);
    indices[dim] = previous_offset;
    result = rewriter
                 .create<vector::TransferWriteOp>(
                     loc, readResult, result, indices,
                     rewriter.getMultiDimIdentityMap(rank))
                 ->getResults()[0];
    if (idx != concatOp.getNumOperands() - 1) {
      auto dimOp = rewriter.create<tensor::DimOp>(loc, input, dimValue);
      previous_offset =
          rewriter.create<arith::AddIOp>(loc, dimOp, previous_offset);
    }
  }

  newResults.push_back(result);
  return success();
}
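A sketch of the resulting IR for a two-input `tensor.concat` along dimension 0 (hypothetical shapes and names; in the code above each write's offset is accumulated with `tensor.dim` plus `arith.addi`, which the second write below reflects in simplified form):

```mlir
// Hypothetical input:
//   %0 = tensor.concat dim(0) %a, %b
//       : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
// Roughly what the vectorized form looks like:
%c0 = arith.constant 0 : index
%pad = arith.constant 0.000000e+00 : f32
%dest = tensor.empty() : tensor<5x4xf32>
%ra = vector.transfer_read %a[%c0, %c0], %pad
    {in_bounds = [true, true]} : tensor<2x4xf32>, vector<2x4xf32>
%w0 = vector.transfer_write %ra, %dest[%c0, %c0]
    {in_bounds = [true, true]} : vector<2x4xf32>, tensor<5x4xf32>
%off = tensor.dim %a, %c0 : tensor<2x4xf32>
%rb = vector.transfer_read %b[%c0, %c0], %pad
    {in_bounds = [true, true]} : tensor<3x4xf32>, vector<3x4xf32>
%w1 = vector.transfer_write %rb, %w0[%off, %c0]
    {in_bounds = [true, true]} : vector<3x4xf32>, tensor<5x4xf32>
```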

// TODO: probably need some extra checks for reduction followed by consumer
// ops that may not commute (e.g. linear reduction + non-linear instructions).
static LogicalResult reductionPreconditions(LinalgOp op) {

@@ -1931,6 +2134,108 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
  return success();
}

static LogicalResult
lowerExpandOpPrecondition(tensor::ExpandShapeOp expandOp,
                          ArrayRef<int64_t> inputVectorSizes) {
  auto resultType = expandOp->getResultTypes()[0];
  auto resultShape = mlir::dyn_cast<ShapedType>(resultType);
  // check reassociation
  llvm::SmallVector<int64_t> associateIndices;
  for (auto &attr : expandOp.getReassociation()) {
    for (auto &indice : mlir::dyn_cast<ArrayAttr>(attr)) {

Review comment: Nit: the singular of "indices" is "index".

      associateIndices.push_back(mlir::dyn_cast<IntegerAttr>(indice).getInt());
    }
  }

  if (llvm::any_of(associateIndices,
                   [](int64_t x) { return x == ShapedType::kDynamic; })) {
    LDBG("Reassociation must be static: " << expandOp << "\n");
    return failure();
  }
  // check input and output shape
  if (!resultShape.hasStaticShape() ||
      !expandOp.getSrcType().hasStaticShape()) {
    LDBG("Input and output shape must be static: " << expandOp << "\n");
    return failure();
  }
  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(resultShape.getShape(),
                                              inputVectorSizes)))
    return failure();

  return success();
}

static LogicalResult
lowerBitcastOpPrecondition(tensor::BitcastOp bitCastOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto resultType = bitCastOp->getResultTypes()[0];
  auto resultShapeType = mlir::dyn_cast<ShapedType>(resultType);
  auto srcType = bitCastOp.getSource().getType();
  auto srcShapeType = mlir::dyn_cast<ShapedType>(srcType);

  bool isStaticInputOutput =
      resultShapeType.hasStaticShape() && srcShapeType.hasStaticShape();
  if (!isStaticInputOutput) {
    LDBG("Input and output shape must be static: " << bitCastOp << "\n");
    return failure();
  }

  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(resultShapeType.getShape(),
                                              inputVectorSizes)))
    return failure();
  return success();
}

static LogicalResult
lowerCollapseShapeOpPrecondition(tensor::CollapseShapeOp collapseOp,
                                 ArrayRef<int64_t> inputVectorSizes) {
  auto resultType = collapseOp->getResultTypes()[0];
  auto resultShapeType = mlir::dyn_cast<ShapedType>(resultType);
  auto srcShapeType = collapseOp.getSrcType();

  bool isStaticInputOutput =
      resultShapeType.hasStaticShape() && srcShapeType.hasStaticShape();
  if (!isStaticInputOutput) {
    LDBG("Input and output shape must be static: " << collapseOp << "\n");
    return failure();
  }

  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(resultShapeType.getShape(),
                                              inputVectorSizes)))
    return failure();
  return success();
}
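To make the static-shape requirement concrete, here is a hypothetical example of an op these preconditions reject; any dynamic dimension in the source or result type causes the precondition to fail:

```mlir
// Rejected: the source and result shapes are not fully static.
%0 = tensor.collapse_shape %arg0 [[0, 1]]
    : tensor<?x4xf32> into tensor<?xf32>
```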

static LogicalResult
lowerConcatOpPrecondition(tensor::ConcatOp concatOp,
                          ArrayRef<int64_t> inputVectorSizes) {
  if (!inputVectorSizes.empty()) {
    LDBG("Concat operation does not support specifying inputVectorSizes: "
         << concatOp << "\n");
  }
  for (auto x : concatOp->getOperands()) {
    auto type = mlir::dyn_cast<ShapedType>(x.getType());
    if (!type) {
      LDBG("Operation type error: " << concatOp << "\n");
      return failure();
    }

Review comment (lines +2221 to +2224): Can this ever happen?

    if (!type.hasStaticShape()) {
      LDBG("Type must be static: " << concatOp << "\n");
      return failure();
    }
  }
  auto dim = concatOp.getDim();
  if (dim >= (uint64_t)concatOp.getResultType().getRank()) {
    LDBG("Invalid dim: " << concatOp << "\n");
    return failure();
  }

  return success();
}
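In the same spirit, a hypothetical pair of `tensor.concat` ops illustrating this precondition: statically shaped operands are accepted, while any dynamically shaped operand is rejected (and user-specified vector sizes are only reported in a debug message rather than honored):

```mlir
// Accepted: every operand and the result are statically shaped.
%ok = tensor.concat dim(0) %a, %b
    : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>

// Rejected: an operand has a dynamic shape.
%ko = tensor.concat dim(0) %c, %d
    : (tensor<?x4xf32>, tensor<3x4xf32>) -> tensor<?x4xf32>
```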

/// Preconditions for scalable vectors.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,

@@ -1976,6 +2281,19 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
      .Case<tensor::UnPackOp>([&](auto unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Case<tensor::ExpandShapeOp>([&](auto expandShapeOp) {
        return lowerExpandOpPrecondition(expandShapeOp, inputVectorSizes);

Review comment: It is a strange choice to call these …

      })
      .Case<tensor::CollapseShapeOp>([&](auto collapseShapeOp) {
        return lowerCollapseShapeOpPrecondition(collapseShapeOp,
                                                inputVectorSizes);
      })
      .Case<tensor::BitcastOp>([&](auto bitCastOp) {
        return lowerBitcastOpPrecondition(bitCastOp, inputVectorSizes);
      })
      .Case<tensor::ConcatOp>([&](auto concatOp) {
        return lowerConcatOpPrecondition(concatOp, inputVectorSizes);
      })
      .Default([](auto) { return failure(); });
}
@@ -2075,6 +2393,22 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputVectorSizes, results);
          })
          .Case<tensor::ExpandShapeOp>([&](auto expandShapeOp) {
            return lowerTensorReshape(rewriter, expandShapeOp, inputVectorSizes,
                                      results);
          })
          .Case<tensor::CollapseShapeOp>([&](auto collapseShapeOp) {
            return lowerTensorReshape(rewriter, collapseShapeOp,
                                      inputVectorSizes, results);
          })
          .Case<tensor::BitcastOp>([&](auto bitCastOp) {
            return lowerTensorBitcastOp(rewriter, bitCastOp, inputVectorSizes,
                                        results);
          })
          .Case<tensor::ConcatOp>([&](auto concatOp) {
            return lowerTensorConcatOp(rewriter, concatOp, inputVectorSizes,
                                       results);
          })
          .Default([](auto) { return failure(); });

  if (failed(vectorizeResult)) {

Review comment: Please expand `auto` unless the type is obvious from line-level context.
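For end-to-end context, these new cases are dispatched through the same `linalg::vectorize` entry point as the existing pad/pack/unpack support, so they can be driven from the transform dialect. The script below is only an illustrative sketch under that assumption; it is not taken from this PR's tests, and the matched op name and sequence structure are made up for the example:

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %arg0: !transform.any_op {transform.readonly}) {
    // Match the newly supported op and vectorize it with static shapes.
    %0 = transform.structured.match ops{["tensor.collapse_shape"]} in %arg0
        : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 : !transform.any_op
    transform.yield
  }
}
```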