diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 28b5864914f69..64fbd722a4f02 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2248,7 +2248,10 @@ def Vector_ConstantMaskOp : define a hyper-rectangular region within which elements values are set to 1 (otherwise element values are set to 0). Each value of 'mask_dim_sizes' must be non-negative and not greater than the size of the corresponding vector - dimension (as opposed to vector.create_mask which allows this). + dimension (as opposed to vector.create_mask which allows this). Sizes that + correspond to scalable dimensions are implicitly multiplied by vscale, + though currently only zero (none set) or the size of the dim/vscale + (all set) are supported. Example: diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index a8ad05f7bc1ca..3c68cb26fb55a 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5320,13 +5320,18 @@ LogicalResult ConstantMaskOp::verify() { // Verify that each array attr element is in bounds of corresponding vector // result dimension size. auto resultShape = resultType.getShape(); + auto resultScalableDims = resultType.getScalableDims(); SmallVector maskDimSizes; - for (const auto &it : llvm::enumerate(getMaskDimSizes())) { - int64_t attrValue = llvm::cast(it.value()).getInt(); - if (attrValue < 0 || attrValue > resultShape[it.index()]) + for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) { + int64_t maskDimSize = llvm::cast(intAttr).getInt(); + if (maskDimSize < 0 || maskDimSize > resultShape[index]) return emitOpError( "array attr of size out of bounds of vector result dimension size"); - maskDimSizes.push_back(attrValue); + if (resultScalableDims[index] && maskDimSize != 0 && + maskDimSize != resultShape[index]) + return emitOpError( + "only supports 'none set' or 'all set' scalable dimensions"); + maskDimSizes.push_back(maskDimSize); } // Verify that if one mask dim size is zero, they all should be zero (because // the mask region is a conjunction of each mask dimension interval). @@ -5335,14 +5340,6 @@ LogicalResult ConstantMaskOp::verify() { if (anyZeros && !allZeros) return emitOpError("expected all mask dim sizes to be zeros, " "as a result of conjunction with zero mask dim"); - // Verify that if the mask type is scalable, dimensions should be zero because - // constant scalable masks can only be defined for the "none set" or "all set" - // cases, and there is no VLA way to define an "all set" case for - // `vector.constant_mask`. In the future, a convention could be established - // to decide if a specific dimension value could be considered as "all set". - if (resultType.isScalable() && - llvm::cast(getMaskDimSizes()[0]).getInt() != 0) - return emitOpError("expected mask dim sizes for scalable masks to be 0"); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index 9a828ec0b845e..95b5ea011c825 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -105,7 +105,6 @@ class ConstantMaskOpLowering : public OpRewritePattern { PatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto dstType = op.getType(); - auto eltType = dstType.getElementType(); auto dimSizes = op.getMaskDimSizes(); int64_t rank = dstType.getRank(); @@ -115,43 +114,43 @@ class ConstantMaskOpLowering : public OpRewritePattern { bool value = cast(dimSizes[0]).getInt() == 1; rewriter.replaceOpWithNewOp( op, dstType, - DenseIntElementsAttr::get( - VectorType::get(ArrayRef{}, rewriter.getI1Type()), - ArrayRef{value})); + DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()), + value)); return success(); } - // Scalable constant masks can only be lowered for the "none set" case. - if (cast(dstType).isScalable()) { - rewriter.replaceOpWithNewOp( - op, DenseElementsAttr::get(dstType, false)); - return success(); - } - - int64_t trueDim = std::min(dstType.getDimSize(0), - cast(dimSizes[0]).getInt()); + int64_t trueDimSize = cast(dimSizes[0]).getInt(); if (rank == 1) { - // Express constant 1-D case in explicit vector form: - // [T,..,T,F,..,F]. - SmallVector values(dstType.getDimSize(0)); - for (int64_t d = 0; d < trueDim; d++) - values[d] = true; - rewriter.replaceOpWithNewOp( - op, dstType, rewriter.getBoolVectorAttr(values)); + if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) { + // Use constant splat for 'all set' or 'none set' dims. + // This produces correct code for scalable dimensions (it will lower to + // a constant splat). + rewriter.replaceOpWithNewOp( + op, DenseElementsAttr::get(dstType, trueDimSize != 0)); + } else { + // Express constant 1-D case in explicit vector form: + // [T,..,T,F,..,F]. + // Note: The verifier would reject this case for scalable vectors. + SmallVector values(dstType.getDimSize(0), false); + for (int64_t d = 0; d < trueDimSize; d++) + values[d] = true; + rewriter.replaceOpWithNewOp( + op, dstType, rewriter.getBoolVectorAttr(values)); + } return success(); } - VectorType lowType = - VectorType::get(dstType.getShape().drop_front(), eltType); - SmallVector newDimSizes; - for (int64_t r = 1; r < rank; r++) - newDimSizes.push_back(cast(dimSizes[r]).getInt()); + if (dstType.getScalableDims().front()) + return rewriter.notifyMatchFailure( + op, "Cannot unroll leading scalable dim in dstType"); + + VectorType lowType = VectorType::Builder(dstType).dropDim(0); Value trueVal = rewriter.create( - loc, lowType, rewriter.getI64ArrayAttr(newDimSizes)); + loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front())); Value result = rewriter.create( loc, dstType, rewriter.getZeroAttr(dstType)); - for (int64_t d = 0; d < trueDim; d++) + for (int64_t d = 0; d < trueDimSize; d++) result = rewriter.create(loc, dstType, trueVal, result, d); rewriter.replaceOp(op, result); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 7b29ef44c1f2f..083b3af90e8c5 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1819,16 +1819,53 @@ func.func @genbool_1d() -> vector<8xi1> { // ----- -func.func @genbool_1d_scalable() -> vector<[8]xi1> { +func.func @genbool_1d_scalable_all_false() -> vector<[8]xi1> { %0 = vector.constant_mask [0] : vector<[8]xi1> return %0 : vector<[8]xi1> } -// CHECK-LABEL: func @genbool_1d_scalable +// CHECK-LABEL: func @genbool_1d_scalable_all_false // CHECK: %[[VAL_0:.*]] = arith.constant dense : vector<[8]xi1> // CHECK: return %[[VAL_0]] : vector<[8]xi1> // ----- +func.func @genbool_1d_scalable_all_true() -> vector<[8]xi1> { + %0 = vector.constant_mask [8] : vector<[8]xi1> + return %0 : vector<[8]xi1> +} +// CHECK-LABEL: func @genbool_1d_scalable_all_true +// CHECK: %[[VAL_0:.*]] = arith.constant dense : vector<[8]xi1> +// CHECK: return %[[VAL_0]] : vector<[8]xi1> + +// ----- + +func.func @genbool_2d_trailing_scalable() -> vector<4x[4]xi1> { + %0 = vector.constant_mask [2, 4] : vector<4x[4]xi1> + return %0 : vector<4x[4]xi1> +} +// CHECK-LABEL: func.func @genbool_2d_trailing_scalable +// CHECK: %[[VAL_0:.*]] = arith.constant dense : vector<[4]xi1> +// CHECK: %[[VAL_1:.*]] = arith.constant dense : vector<4x[4]xi1> +// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x[4]xi1> to !llvm.array<4 x vector<[4]xi1>> +// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<[4]xi1>> +// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<[4]xi1>> +// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<[4]xi1>> to vector<4x[4]xi1> +// CHECK: return %[[VAL_5]] : vector<4x[4]xi1> + +// ----- + +/// Currently, this is not supported as generating the mask would require +/// unrolling the leading scalable dimension at compile time. +func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> { + %0 = vector.constant_mask [4, 2] : vector<[4]x4xi1> + return %0 : vector<[4]x4xi1> +} +// CHECK-LABEL: func.func @cannot_genbool_2d_leading_scalable +// CHECK: %[[VAL_0:.*]] = vector.constant_mask [4, 2] : vector<[4]x4xi1> +// CHECK: return %[[VAL_0]] : vector<[4]x4xi1> + +// ----- + func.func @genbool_2d() -> vector<4x4xi1> { %v = vector.constant_mask [2, 2] : vector<4x4xi1> return %v: vector<4x4xi1> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 50119c2b4a362..26772b9294935 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -995,7 +995,7 @@ func.func @constant_mask_with_zero_mask_dim_size() { // ----- func.func @constant_mask_scalable_non_zero_dim_size() { - // expected-error@+1 {{expected mask dim sizes for scalable masks to be 0}} + // expected-error@+1 {{only supports 'none set' or 'all set' scalable dimensions}} %0 = vector.constant_mask [2] : vector<[8]xi1> } diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 4ea4379372e83..18a4c202edfb0 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -448,6 +448,10 @@ func.func @constant_vector_mask() { %0 = vector.constant_mask [3, 2] : vector<4x3xi1> // CHECK: vector.constant_mask [0] : vector<[4]xi1> %1 = vector.constant_mask [0] : vector<[4]xi1> + // CHECK: vector.constant_mask [4] : vector<[4]xi1> + %2 = vector.constant_mask [4] : vector<[4]xi1> + // CHECK: vector.constant_mask [1, 4] : vector<2x[4]xi1> + %3 = vector.constant_mask [1, 4] : vector<2x[4]xi1> return }