Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3388,8 +3388,9 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(

// TODO: Check that the correct number of vectorSizes was provided.
for (Operation *target : targets) {
if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
target)) {
if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp,
tensor::BitcastOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp,
tensor::ConcatOp>(target)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
Expand Down
334 changes: 334 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1718,6 +1718,209 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
return success();
}

/// Vectorize a `tensor::expandshape` to these 3 Ops:
/// Vector::TransferReadOp - Reads a vector from the source tensor
/// ShapeCastOp - Reshape the data based on the target.
/// vector::TransferWriteOp. - Write the result vector back to the destination
/// tensor
static LogicalResult lowerTensorReshape(RewriterBase &rewriter,
Operation *inputOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(inputOp);
auto src = inputOp->getOperand(0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please expand auto unless the type is obvious from line-level context.

auto srcType = mlir::dyn_cast<ShapedType>(src.getType());
auto result = inputOp->getResults()[0];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In MLIR, we generally dislike "raw" indexed accessors. Can this function be turned into a function template and use named accessors instead?

auto resultType = mlir::dyn_cast<ShapedType>(result.getType());
ArrayRef<int64_t> resultShape = resultType.getShape();
ArrayRef<int64_t> srcShape = srcType.getShape();
Comment on lines +1736 to +1737
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This most likely needs at least an assertion that shapes are static. The logic below will become wild if given a kDynamic.

Location loc = inputOp->getLoc();

llvm::SmallVector<int64_t> srcVectorizedShape;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to prefix SmallVector with llvm::.

llvm::SmallDenseMap<int64_t, int64_t> shapeScales;

auto getVectorizeShape = [&](ArrayRef<int64_t> &retShape,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ArrayRef is already a reference as the name indicates, there is no need to pass it by reference.

ArrayRef<int64_t> &inputShape) {
bool isResultShapeBigger = srcType.getRank() < resultType.getRank();

int64_t cur = 1, resultIdx = 0;
for (auto [srcIdx, ss] : llvm::enumerate(inputShape)) {
cur *= ss;
if (!isResultShapeBigger) {
// collapse
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use full sentences in comments, including capitalization and trailing full stops.

srcVectorizedShape.emplace_back(ss);
if (cur == retShape[resultIdx]) {
if (shapeScales.count(resultIdx)) {
srcVectorizedShape.back() *= shapeScales[resultIdx];
}
cur = 1;
resultIdx++;
}
Comment on lines +1753 to +1759
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

} else {
// expand
if (cur == retShape[resultIdx]) {
srcVectorizedShape.emplace_back(cur);
if (shapeScales.count(srcIdx)) {
srcVectorizedShape.back() *= shapeScales[srcIdx];
}
cur = 1;
resultIdx++;
}
}
}
};
if (!inputVectorSizes.empty()) {
for (auto [idx, vs] : llvm::enumerate(inputVectorSizes)) {
if (vs != resultShape[idx])
shapeScales[idx] = vs / resultShape[idx];
}

bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
if (!isResultShapeBigger) {
getVectorizeShape(resultShape, srcShape);
} else {
getVectorizeShape(srcShape, resultShape);
}
} else {
srcVectorizedShape.assign(srcShape.begin(), srcShape.end());
}
// read
auto padValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(srcType.getElementType()));
Value readResult = vector::createReadOrMaskedRead(
rewriter, loc, src,
inputVectorSizes.empty() ? srcType.getShape() : srcVectorizedShape,
padValue, false);

auto shapeCastType =
VectorType::get(inputVectorSizes.empty() ? resultShape : inputVectorSizes,
resultType.getElementType());
vector::ShapeCastOp shapeCastOp =
rewriter.create<vector::ShapeCastOp>(loc, shapeCastType, readResult);

// write
SmallVector<OpFoldResult> destSizes;
for (auto size : resultShape) {
destSizes.emplace_back(rewriter.getIndexAttr(size));
}
Comment on lines +1803 to +1806
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reserve space before appending in a loop. Or better, use a proper combinator like map_to_vector.

Operation *write = createWriteOrMaskedWrite(
rewriter, loc, shapeCastOp->getResults()[0], destSizes,
inputVectorSizes.empty() ? resultShape : inputVectorSizes, false);
newResults.push_back(write->getResult(0));
return success();
}

/// Vectorize a `tensor::bitcast` to these 3 Ops:
/// vector::TransferReadOp - Reads a vector from the source tensor
/// vector.Bitcast - Bitcast the data based on the target.
/// vector::TransferWriteOp. - Write the result vector back to the destination
/// tensor
static LogicalResult lowerTensorBitcastOp(RewriterBase &rewriter,
tensor::BitcastOp bitCastOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(bitCastOp);

auto sourceType = bitCastOp.getSource().getType();
auto resultType = bitCastOp.getResult().getType();
auto resultShape = resultType.getShape();
if (inputVectorSizes.empty()) {
inputVectorSizes = resultShape;
}
Location loc = bitCastOp->getLoc();

// read
auto padValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(sourceType.getElementType()));
Value readResult = vector::createReadOrMaskedRead(
rewriter, loc, bitCastOp.getSource(), inputVectorSizes, padValue, false);

// bitcast
auto resultVectorType =
VectorType::get(inputVectorSizes, resultType.getElementType());
vector::BitCastOp vectorbitCastOp =
rewriter.create<vector::BitCastOp>(loc, resultVectorType, readResult);

// write
llvm::SmallVector<OpFoldResult> destSizes;
for (auto size : resultShape)
destSizes.emplace_back(rewriter.getIndexAttr(size));
auto write =
createWriteOrMaskedWrite(rewriter, loc, vectorbitCastOp->getResult(0),
destSizes, inputVectorSizes, false);
newResults.push_back(write->getResults()[0]);
return success();
}

/// Vectorize a `tensor::concat` to these 3 Ops:
/// Tensor::EmptyOp - The result tensor.
/// Vector::TransferWriteOp - Write the result vector back to the destination
/// tensor.
/// Vector::TransferWriteOp - Write the result vector back to the destination
/// tensor.
static LogicalResult lowerTensorConcatOp(RewriterBase &rewriter,
tensor::ConcatOp concatOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(concatOp);

Location loc = concatOp.getLoc();
FailureOr<Value> dest =
tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
if (failed(dest))
return failure();

auto empty = dest->getDefiningOp<tensor::EmptyOp>();
if (!empty)
return failure();

// Compute the partial sums for the slice offsets.
auto dim = concatOp.getDim();
Value dimValue =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));

int64_t rank = concatOp.getResultType().getRank();
auto srcType =
mlir::dyn_cast<RankedTensorType>(concatOp->getResultTypes()[0]);
auto padValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(srcType.getElementType()));

// Construct the chain of insert_slice ops into the destination.
Value result = *dest;
Value previous_offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
for (auto [idx, input] : llvm::enumerate(concatOp.getInputs())) {

SmallVector<OpFoldResult> sizes =
tensor::getMixedSizes(rewriter, loc, input);
SmallVector<int64_t> readMaskShape;
auto inputType = mlir::dyn_cast<RankedTensorType>(input.getType());
auto sourceShape = inputType.getShape();

readMaskShape.append(sourceShape.begin(), sourceShape.end());
Value readResult = vector::createReadOrMaskedRead(
rewriter, loc, input, sourceShape, padValue, false);
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> indices(rank, zero);
indices[dim] = previous_offset;
result = rewriter
.create<vector::TransferWriteOp>(
loc, readResult, result, indices,
rewriter.getMultiDimIdentityMap(rank))
->getResults()[0];
if (idx != concatOp.getNumOperands() - 1) {
auto dimOp = rewriter.create<tensor::DimOp>(loc, input, dimValue);
previous_offset =
rewriter.create<arith::AddIOp>(loc, dimOp, previous_offset);
}
}

newResults.push_back(result);
return success();
}

// TODO: probably need some extra checks for reduction followed by consumer
// ops that may not commute (e.g. linear reduction + non-linear instructions).
static LogicalResult reductionPreconditions(LinalgOp op) {
Expand Down Expand Up @@ -1931,6 +2134,108 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
return success();
}

static LogicalResult
lowerExpandOpPrecondition(tensor::ExpandShapeOp expandOp,
ArrayRef<int64_t> inputVectorSizes) {
auto resultType = expandOp->getResultTypes()[0];
auto resultShape = mlir::dyn_cast<ShapedType>(resultType);
// check reassociation
llvm::SmallVector<int64_t> associateIndices;
for (auto &attr : expandOp.getReassociation()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If auto had been expanded, it would have been clear that a reference is not necessary here.

for (auto &indice : mlir::dyn_cast<ArrayAttr>(attr)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: the singular of "indices" is "index".

associateIndices.push_back(mlir::dyn_cast<IntegerAttr>(indice).getInt());
}
}

if (llvm::any_of(associateIndices,
[](int64_t x) { return x == ShapedType::kDynamic; })) {
LDBG("Reassociation must be static: " << expandOp << "\n");
return failure();
}
// check input and output shape
if (!resultShape.hasStaticShape() ||
!expandOp.getSrcType().hasStaticShape()) {
LDBG("Input and output shape must be static: " << expandOp << "\n");
return failure();
}
if (!inputVectorSizes.empty() &&
failed(vector::isValidMaskedInputVector(resultShape.getShape(),
inputVectorSizes)))
return failure();

return success();
}

static LogicalResult
lowerBitcastOpPrecondition(tensor::BitcastOp bitCastOp,
ArrayRef<int64_t> inputVectorSizes) {
auto resultType = bitCastOp->getResultTypes()[0];
auto resultShapeType = mlir::dyn_cast<ShapedType>(resultType);
auto srcType = bitCastOp.getSource().getType();
auto srcShapeType = mlir::dyn_cast<ShapedType>(srcType);

bool isStaticInputOutput =
resultShapeType.hasStaticShape() && srcShapeType.hasStaticShape();
if (!isStaticInputOutput) {
LDBG("Input and output shape must be static: " << bitCastOp << "\n");
return failure();
}

if (!inputVectorSizes.empty() &&
failed(vector::isValidMaskedInputVector(resultShapeType.getShape(),
inputVectorSizes)))
return failure();
return success();
}

static LogicalResult
lowerCollapseShapeOpPrecondition(tensor::CollapseShapeOp collapseOp,
ArrayRef<int64_t> inputVectorSizes) {
auto resultType = collapseOp->getResultTypes()[0];
auto resultShapeType = mlir::dyn_cast<ShapedType>(resultType);
auto srcShapeType = collapseOp.getSrcType();

bool isStaticInputOutput =
resultShapeType.hasStaticShape() && srcShapeType.hasStaticShape();
if (!isStaticInputOutput) {
LDBG("Input and output shape must be static: " << collapseOp << "\n");
return failure();
}

if (!inputVectorSizes.empty() &&
failed(vector::isValidMaskedInputVector(resultShapeType.getShape(),
inputVectorSizes)))
return failure();
return success();
}

static LogicalResult
lowerConcatOpPrecondition(tensor::ConcatOp concatOp,
ArrayRef<int64_t> inputVectorSizes) {
if (!inputVectorSizes.empty()) {
LDBG("Concat operation do not support specify inputVectorSizes: "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
LDBG("Concat operation do not support specify inputVectorSizes: "
LDBG("Concat operation does not support specify inputVectorSizes: "

<< concatOp << "\n");
}
for (auto x : concatOp->getOperands()) {
auto type = mlir::dyn_cast<ShapedType>(x.getType());
if (!type) {
LDBG("Operation type error: " << concatOp << "\n");
return failure();
}
Comment on lines +2221 to +2224
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this ever happen?

if (!type.hasStaticShape()) {
LDBG("Type must be static: " << concatOp << "\n");
return failure();
}
}
auto dim = concatOp.getDim();
if (dim >= (uint64_t)concatOp.getResultType().getRank()) {
LDBG("Invalid dim: " << concatOp << "\n");
return failure();
}

return success();
}

/// Preconditions for scalable vectors.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
Expand Down Expand Up @@ -1976,6 +2281,19 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::UnPackOp>([&](auto unpackOp) {
return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
})
.Case<tensor::ExpandShapeOp>([&](auto expandShapeOp) {
return lowerExpandOpPrecondition(expandShapeOp, inputVectorSizes);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a strange choice to call these lowerFoo when everything above is called vectorizeFoo.

})
.Case<tensor::CollapseShapeOp>([&](auto collapseShapeOp) {
return lowerCollapseShapeOpPrecondition(collapseShapeOp,
inputVectorSizes);
})
.Case<tensor::BitcastOp>([&](auto bitCastOp) {
return lowerBitcastOpPrecondition(bitCastOp, inputVectorSizes);
})
.Case<tensor::ConcatOp>([&](auto concatOp) {
return lowerConcatOpPrecondition(concatOp, inputVectorSizes);
})
.Default([](auto) { return failure(); });
}

Expand Down Expand Up @@ -2075,6 +2393,22 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
inputVectorSizes, results);
})
.Case<tensor::ExpandShapeOp>([&](auto expandShapeOp) {
return lowerTensorReshape(rewriter, expandShapeOp, inputVectorSizes,
results);
})
.Case<tensor::CollapseShapeOp>([&](auto collapseShapeOp) {
return lowerTensorReshape(rewriter, collapseShapeOp,
inputVectorSizes, results);
})
.Case<tensor::BitcastOp>([&](auto bitCastOp) {
return lowerTensorBitcastOp(rewriter, bitCastOp, inputVectorSizes,
results);
})
.Case<tensor::ConcatOp>([&](auto concatOp) {
return lowerTensorConcatOp(rewriter, concatOp, inputVectorSizes,
results);
})
.Default([](auto) { return failure(); });

if (failed(vectorizeResult)) {
Expand Down
Loading