Skip to content

[mlir][VectorOps] Extend vector.constant_mask to support 'all true' scalable dims #66638

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2248,7 +2248,10 @@ def Vector_ConstantMaskOp :
define a hyper-rectangular region within which elements values are set to 1
(otherwise element values are set to 0). Each value of 'mask_dim_sizes' must
be non-negative and not greater than the size of the corresponding vector
dimension (as opposed to vector.create_mask which allows this).
dimension (as opposed to vector.create_mask which allows this). Sizes that
correspond to scalable dimensions are implicitly multiplied by vscale,
though currently only zero (none set) or the size of the dim/vscale
(all set) are supported.

Example:

Expand Down
21 changes: 9 additions & 12 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5320,13 +5320,18 @@ LogicalResult ConstantMaskOp::verify() {
// Verify that each array attr element is in bounds of corresponding vector
// result dimension size.
auto resultShape = resultType.getShape();
auto resultScalableDims = resultType.getScalableDims();
SmallVector<int64_t, 4> maskDimSizes;
for (const auto &it : llvm::enumerate(getMaskDimSizes())) {
int64_t attrValue = llvm::cast<IntegerAttr>(it.value()).getInt();
if (attrValue < 0 || attrValue > resultShape[it.index()])
for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
if (maskDimSize < 0 || maskDimSize > resultShape[index])
return emitOpError(
"array attr of size out of bounds of vector result dimension size");
maskDimSizes.push_back(attrValue);
if (resultScalableDims[index] && maskDimSize != 0 &&
maskDimSize != resultShape[index])
return emitOpError(
"only supports 'none set' or 'all set' scalable dimensions");
maskDimSizes.push_back(maskDimSize);
}
// Verify that if one mask dim size is zero, they all should be zero (because
// the mask region is a conjunction of each mask dimension interval).
Expand All @@ -5335,14 +5340,6 @@ LogicalResult ConstantMaskOp::verify() {
if (anyZeros && !allZeros)
return emitOpError("expected all mask dim sizes to be zeros, "
"as a result of conjunction with zero mask dim");
// Verify that if the mask type is scalable, dimensions should be zero because
// constant scalable masks can only be defined for the "none set" or "all set"
// cases, and there is no VLA way to define an "all set" case for
// `vector.constant_mask`. In the future, a convention could be established
// to decide if a specific dimension value could be considered as "all set".
if (resultType.isScalable() &&
llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() != 0)
return emitOpError("expected mask dim sizes for scalable masks to be 0");
return success();
}

Expand Down
53 changes: 26 additions & 27 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto dstType = op.getType();
auto eltType = dstType.getElementType();
auto dimSizes = op.getMaskDimSizes();
int64_t rank = dstType.getRank();

Expand All @@ -115,43 +114,43 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
DenseIntElementsAttr::get(
VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
ArrayRef<bool>{value}));
DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
value));
return success();
}

// Scalable constant masks can only be lowered for the "none set" case.
if (cast<VectorType>(dstType).isScalable()) {
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, DenseElementsAttr::get(dstType, false));
return success();
}

int64_t trueDim = std::min(dstType.getDimSize(0),
cast<IntegerAttr>(dimSizes[0]).getInt());
int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();

if (rank == 1) {
// Express constant 1-D case in explicit vector form:
// [T,..,T,F,..,F].
SmallVector<bool> values(dstType.getDimSize(0));
for (int64_t d = 0; d < trueDim; d++)
values[d] = true;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType, rewriter.getBoolVectorAttr(values));
if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
// Use constant splat for 'all set' or 'none set' dims.
// This produces correct code for scalable dimensions (it will lower to
// a constant splat).
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, DenseElementsAttr::get(dstType, trueDimSize != 0));
} else {
// Express constant 1-D case in explicit vector form:
// [T,..,T,F,..,F].
// Note: The verifier would reject this case for scalable vectors.
SmallVector<bool> values(dstType.getDimSize(0), false);
for (int64_t d = 0; d < trueDimSize; d++)
values[d] = true;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType, rewriter.getBoolVectorAttr(values));
}
return success();
}

VectorType lowType =
VectorType::get(dstType.getShape().drop_front(), eltType);
SmallVector<int64_t> newDimSizes;
for (int64_t r = 1; r < rank; r++)
newDimSizes.push_back(cast<IntegerAttr>(dimSizes[r]).getInt());
if (dstType.getScalableDims().front())
return rewriter.notifyMatchFailure(
op, "Cannot unroll leading scalable dim in dstType");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to test for this in invalid.mlir, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this is a match failure not an invalid op (this is just a fancy return failure()), it does not produce a diagnostic.


VectorType lowType = VectorType::Builder(dstType).dropDim(0);
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
for (int64_t d = 0; d < trueDim; d++)
for (int64_t d = 0; d < trueDimSize; d++)
result =
rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, d);
rewriter.replaceOp(op, result);
Expand Down
41 changes: 39 additions & 2 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1819,16 +1819,53 @@ func.func @genbool_1d() -> vector<8xi1> {

// -----

func.func @genbool_1d_scalable() -> vector<[8]xi1> {
func.func @genbool_1d_scalable_all_false() -> vector<[8]xi1> {
%0 = vector.constant_mask [0] : vector<[8]xi1>
return %0 : vector<[8]xi1>
}
// CHECK-LABEL: func @genbool_1d_scalable
// CHECK-LABEL: func @genbool_1d_scalable_all_false
// CHECK: %[[VAL_0:.*]] = arith.constant dense<false> : vector<[8]xi1>
// CHECK: return %[[VAL_0]] : vector<[8]xi1>

// -----

func.func @genbool_1d_scalable_all_true() -> vector<[8]xi1> {
%0 = vector.constant_mask [8] : vector<[8]xi1>
return %0 : vector<[8]xi1>
}
// CHECK-LABEL: func @genbool_1d_scalable_all_true
// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<[8]xi1>
// CHECK: return %[[VAL_0]] : vector<[8]xi1>

// -----

func.func @genbool_2d_trailing_scalable() -> vector<4x[4]xi1> {
%0 = vector.constant_mask [2, 4] : vector<4x[4]xi1>
return %0 : vector<4x[4]xi1>
}
// CHECK-LABEL: func.func @genbool_2d_trailing_scalable
// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK: %[[VAL_1:.*]] = arith.constant dense<false> : vector<4x[4]xi1>
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x[4]xi1> to !llvm.array<4 x vector<[4]xi1>>
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<[4]xi1>>
// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<[4]xi1>>
// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<[4]xi1>> to vector<4x[4]xi1>
// CHECK: return %[[VAL_5]] : vector<4x[4]xi1>

// -----

/// Currently, this is not supported as generating the mask would require
/// unrolling the leading scalable dimension at compile time.
func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> {
%0 = vector.constant_mask [4, 2] : vector<[4]x4xi1>
return %0 : vector<[4]x4xi1>
}
// CHECK-LABEL: func.func @cannot_genbool_2d_leading_scalable
// CHECK: %[[VAL_0:.*]] = vector.constant_mask [4, 2] : vector<[4]x4xi1>
// CHECK: return %[[VAL_0]] : vector<[4]x4xi1>

// -----

func.func @genbool_2d() -> vector<4x4xi1> {
%v = vector.constant_mask [2, 2] : vector<4x4xi1>
return %v: vector<4x4xi1>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -995,7 +995,7 @@ func.func @constant_mask_with_zero_mask_dim_size() {
// -----

func.func @constant_mask_scalable_non_zero_dim_size() {
// expected-error@+1 {{expected mask dim sizes for scalable masks to be 0}}
// expected-error@+1 {{only supports 'none set' or 'all set' scalable dimensions}}
%0 = vector.constant_mask [2] : vector<[8]xi1>
}

Expand Down
4 changes: 4 additions & 0 deletions mlir/test/Dialect/Vector/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,10 @@ func.func @constant_vector_mask() {
%0 = vector.constant_mask [3, 2] : vector<4x3xi1>
// CHECK: vector.constant_mask [0] : vector<[4]xi1>
%1 = vector.constant_mask [0] : vector<[4]xi1>
// CHECK: vector.constant_mask [4] : vector<[4]xi1>
%2 = vector.constant_mask [4] : vector<[4]xi1>
// CHECK: vector.constant_mask [1, 4] : vector<2x[4]xi1>
%3 = vector.constant_mask [1, 4] : vector<2x[4]xi1>
return
}

Expand Down