From b6015b0f27e5612fde28383aa2c1e3e13934bf34 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Mon, 18 Sep 2023 12:20:19 +0000 Subject: [PATCH 1/2] [mlir][VectorOps] Extend vector.constant_mask to support 'all true' scalable dims This extends vector.constant_mask so that mask dim sizes that correspond to a scalable dimension are treated as if they're implicitly multiplied by vscale. Currently this is limited to mask dim sizes of 0 or the size of the dim/vscale. This allows constant masks to represent all true and all false scalable masks (and some variations): // All true scalable mask %mask = vector.constant_mask [8] : vector<[8]xi1> // All false scalable mask %mask = vector.constant_mask [0] : vector<[8]xi1> // First two scalable rows %mask = vector.constant_mask [2,4] : vector<4x[4]xi1> --- .../mlir/Dialect/Vector/IR/VectorOps.td | 5 +- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 21 ++++---- .../Vector/Transforms/LowerVectorMask.cpp | 51 +++++++++---------- .../VectorToLLVM/vector-to-llvm.mlir | 43 +++++++++++++++- mlir/test/Dialect/Vector/invalid.mlir | 2 +- mlir/test/Dialect/Vector/ops.mlir | 6 ++- 6 files changed, 84 insertions(+), 44 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 28b5864914f69..64fbd722a4f02 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2248,7 +2248,10 @@ def Vector_ConstantMaskOp : define a hyper-rectangular region within which elements values are set to 1 (otherwise element values are set to 0). Each value of 'mask_dim_sizes' must be non-negative and not greater than the size of the corresponding vector - dimension (as opposed to vector.create_mask which allows this). + dimension (as opposed to vector.create_mask which allows this). 
Sizes that + correspond to scalable dimensions are implicitly multiplied by vscale, + though currently only zero (none set) or the size of the dim/vscale + (all set) are supported. Example: diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index a8ad05f7bc1ca..3c68cb26fb55a 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5320,13 +5320,18 @@ LogicalResult ConstantMaskOp::verify() { // Verify that each array attr element is in bounds of corresponding vector // result dimension size. auto resultShape = resultType.getShape(); + auto resultScalableDims = resultType.getScalableDims(); SmallVector maskDimSizes; - for (const auto &it : llvm::enumerate(getMaskDimSizes())) { - int64_t attrValue = llvm::cast(it.value()).getInt(); - if (attrValue < 0 || attrValue > resultShape[it.index()]) + for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) { + int64_t maskDimSize = llvm::cast(intAttr).getInt(); + if (maskDimSize < 0 || maskDimSize > resultShape[index]) return emitOpError( "array attr of size out of bounds of vector result dimension size"); - maskDimSizes.push_back(attrValue); + if (resultScalableDims[index] && maskDimSize != 0 && + maskDimSize != resultShape[index]) + return emitOpError( + "only supports 'none set' or 'all set' scalable dimensions"); + maskDimSizes.push_back(maskDimSize); } // Verify that if one mask dim size is zero, they all should be zero (because // the mask region is a conjunction of each mask dimension interval). 
@@ -5335,14 +5340,6 @@ LogicalResult ConstantMaskOp::verify() { if (anyZeros && !allZeros) return emitOpError("expected all mask dim sizes to be zeros, " "as a result of conjunction with zero mask dim"); - // Verify that if the mask type is scalable, dimensions should be zero because - // constant scalable masks can only be defined for the "none set" or "all set" - // cases, and there is no VLA way to define an "all set" case for - // `vector.constant_mask`. In the future, a convention could be established - // to decide if a specific dimension value could be considered as "all set". - if (resultType.isScalable() && - llvm::cast(getMaskDimSizes()[0]).getInt() != 0) - return emitOpError("expected mask dim sizes for scalable masks to be 0"); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index 9a828ec0b845e..418dc6786a76e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -105,7 +105,6 @@ class ConstantMaskOpLowering : public OpRewritePattern { PatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto dstType = op.getType(); - auto eltType = dstType.getElementType(); auto dimSizes = op.getMaskDimSizes(); int64_t rank = dstType.getRank(); @@ -115,43 +114,41 @@ class ConstantMaskOpLowering : public OpRewritePattern { bool value = cast(dimSizes[0]).getInt() == 1; rewriter.replaceOpWithNewOp( op, dstType, - DenseIntElementsAttr::get( - VectorType::get(ArrayRef{}, rewriter.getI1Type()), - ArrayRef{value})); + DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()), + value)); return success(); } - // Scalable constant masks can only be lowered for the "none set" case. 
- if (cast(dstType).isScalable()) { - rewriter.replaceOpWithNewOp( - op, DenseElementsAttr::get(dstType, false)); - return success(); - } - - int64_t trueDim = std::min(dstType.getDimSize(0), - cast(dimSizes[0]).getInt()); + int64_t trueDimSize = cast(dimSizes[0]).getInt(); if (rank == 1) { - // Express constant 1-D case in explicit vector form: - // [T,..,T,F,..,F]. - SmallVector values(dstType.getDimSize(0)); - for (int64_t d = 0; d < trueDim; d++) - values[d] = true; - rewriter.replaceOpWithNewOp( - op, dstType, rewriter.getBoolVectorAttr(values)); + if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) { + // Use constant splat for 'all set' or 'none set' dims. + // This produces correct code for scalable dimensions. + rewriter.replaceOpWithNewOp( + op, DenseElementsAttr::get(dstType, trueDimSize != 0)); + } else { + // Express constant 1-D case in explicit vector form: + // [T,..,T,F,..,F]. + SmallVector values(dstType.getDimSize(0)); + for (int64_t d = 0; d < trueDimSize; d++) + values[d] = true; + rewriter.replaceOpWithNewOp( + op, dstType, rewriter.getBoolVectorAttr(values)); + } return success(); } - VectorType lowType = - VectorType::get(dstType.getShape().drop_front(), eltType); - SmallVector newDimSizes; - for (int64_t r = 1; r < rank; r++) - newDimSizes.push_back(cast(dimSizes[r]).getInt()); + if (dstType.getScalableDims().front()) + return rewriter.notifyMatchFailure( + op, "Cannot unroll leading scalable dim in dstType"); + + VectorType lowType = VectorType::Builder(dstType).dropDim(0); Value trueVal = rewriter.create( - loc, lowType, rewriter.getI64ArrayAttr(newDimSizes)); + loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front())); Value result = rewriter.create( loc, dstType, rewriter.getZeroAttr(dstType)); - for (int64_t d = 0; d < trueDim; d++) + for (int64_t d = 0; d < trueDimSize; d++) result = rewriter.create(loc, dstType, trueVal, result, d); rewriter.replaceOp(op, result); diff --git 
a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 7b29ef44c1f2f..27bd5b5ea0eed 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1819,16 +1819,55 @@ func.func @genbool_1d() -> vector<8xi1> { // ----- -func.func @genbool_1d_scalable() -> vector<[8]xi1> { +func.func @genbool_1d_scalable_pfalse() -> vector<[8]xi1> { %0 = vector.constant_mask [0] : vector<[8]xi1> return %0 : vector<[8]xi1> } -// CHECK-LABEL: func @genbool_1d_scalable +// CHECK-LABEL: func @genbool_1d_scalable_pfalse // CHECK: %[[VAL_0:.*]] = arith.constant dense : vector<[8]xi1> // CHECK: return %[[VAL_0]] : vector<[8]xi1> // ----- +func.func @genbool_1d_scalable_ptrue() -> vector<[8]xi1> { + %0 = vector.constant_mask [8] : vector<[8]xi1> + return %0 : vector<[8]xi1> +} +// CHECK-LABEL: func @genbool_1d_scalable_ptrue +// CHECK: %[[VAL_0:.*]] = arith.constant dense : vector<[8]xi1> +// CHECK: return %[[VAL_0]] : vector<[8]xi1> + +// ----- + +func.func @genbool_2d_scalable() -> vector<4x[4]xi1> { + %0 = vector.constant_mask [2, 4] : vector<4x[4]xi1> + return %0 : vector<4x[4]xi1> +} +// CHECK-LABEL: func.func @genbool_2d_scalable() -> vector<4x[4]xi1> { +// CHECK: %[[VAL_0:.*]] = arith.constant dense : vector<[4]xi1> +// CHECK: %[[VAL_1:.*]] = arith.constant dense : vector<4x[4]xi1> +// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x[4]xi1> to !llvm.array<4 x vector<[4]xi1>> +// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<[4]xi1>> +// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<[4]xi1>> +// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<[4]xi1>> to vector<4x[4]xi1> +// CHECK: return %[[VAL_5]] : vector<4x[4]xi1> +// CHECK: } + +// ----- + +/// Currently, this is not supported as 
generating the mask would require +/// unrolling the leading scalable dimension at compile time. +func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> { + %0 = vector.constant_mask [4, 2] : vector<[4]x4xi1> + return %0 : vector<[4]x4xi1> +} +// CHECK-LABEL: func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> { +// CHECK: %[[VAL_0:.*]] = vector.constant_mask [4, 2] : vector<[4]x4xi1> +// CHECK: return %[[VAL_0]] : vector<[4]x4xi1> +// CHECK: } + +// ----- + func.func @genbool_2d() -> vector<4x4xi1> { %v = vector.constant_mask [2, 2] : vector<4x4xi1> return %v: vector<4x4xi1> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 50119c2b4a362..26772b9294935 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -995,7 +995,7 @@ func.func @constant_mask_with_zero_mask_dim_size() { // ----- func.func @constant_mask_scalable_non_zero_dim_size() { - // expected-error@+1 {{expected mask dim sizes for scalable masks to be 0}} + // expected-error@+1 {{only supports 'none set' or 'all set' scalable dimensions}} %0 = vector.constant_mask [2] : vector<[8]xi1> } diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 4ea4379372e83..96c56946cd1cf 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -448,6 +448,10 @@ func.func @constant_vector_mask() { %0 = vector.constant_mask [3, 2] : vector<4x3xi1> // CHECK: vector.constant_mask [0] : vector<[4]xi1> %1 = vector.constant_mask [0] : vector<[4]xi1> + // CHECK: vector.constant_mask [4] : vector<[4]xi1> + %2 = vector.constant_mask [4] : vector<[4]xi1> + // CHECK: vector.constant_mask [1, 4] : vector<2x[4]xi1> + %3 = vector.constant_mask [1, 4] : vector<2x[4]xi1> return } @@ -1003,7 +1007,7 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>, %C: vector<3x[8]xf32>, %M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> { // CHECK: vector.mask 
%[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32> - %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } + %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32> return %0 : vector<3x[8]xf32> } From 91e4da628e8278694aeb2af834616f7f0e1a769e Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 20 Sep 2023 11:07:47 +0000 Subject: [PATCH 2/2] Review fixups --- .../Vector/Transforms/LowerVectorMask.cpp | 6 ++++-- .../Conversion/VectorToLLVM/vector-to-llvm.mlir | 16 +++++++--------- mlir/test/Dialect/Vector/ops.mlir | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index 418dc6786a76e..95b5ea011c825 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -124,13 +124,15 @@ class ConstantMaskOpLowering : public OpRewritePattern { if (rank == 1) { if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) { // Use constant splat for 'all set' or 'none set' dims. - // This produces correct code for scalable dimensions. + // This produces correct code for scalable dimensions (it will lower to + // a constant splat). rewriter.replaceOpWithNewOp( op, DenseElementsAttr::get(dstType, trueDimSize != 0)); } else { // Express constant 1-D case in explicit vector form: // [T,..,T,F,..,F]. - SmallVector values(dstType.getDimSize(0)); + // Note: The verifier would reject this case for scalable vectors. 
+ SmallVector values(dstType.getDimSize(0), false); for (int64_t d = 0; d < trueDimSize; d++) values[d] = true; rewriter.replaceOpWithNewOp( diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 27bd5b5ea0eed..083b3af90e8c5 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1819,31 +1819,31 @@ func.func @genbool_1d() -> vector<8xi1> { // ----- -func.func @genbool_1d_scalable_pfalse() -> vector<[8]xi1> { +func.func @genbool_1d_scalable_all_false() -> vector<[8]xi1> { %0 = vector.constant_mask [0] : vector<[8]xi1> return %0 : vector<[8]xi1> } -// CHECK-LABEL: func @genbool_1d_scalable_pfalse +// CHECK-LABEL: func @genbool_1d_scalable_all_false // CHECK: %[[VAL_0:.*]] = arith.constant dense : vector<[8]xi1> // CHECK: return %[[VAL_0]] : vector<[8]xi1> // ----- -func.func @genbool_1d_scalable_ptrue() -> vector<[8]xi1> { +func.func @genbool_1d_scalable_all_true() -> vector<[8]xi1> { %0 = vector.constant_mask [8] : vector<[8]xi1> return %0 : vector<[8]xi1> } -// CHECK-LABEL: func @genbool_1d_scalable_ptrue +// CHECK-LABEL: func @genbool_1d_scalable_all_true // CHECK: %[[VAL_0:.*]] = arith.constant dense : vector<[8]xi1> // CHECK: return %[[VAL_0]] : vector<[8]xi1> // ----- -func.func @genbool_2d_scalable() -> vector<4x[4]xi1> { +func.func @genbool_2d_trailing_scalable() -> vector<4x[4]xi1> { %0 = vector.constant_mask [2, 4] : vector<4x[4]xi1> return %0 : vector<4x[4]xi1> } -// CHECK-LABEL: func.func @genbool_2d_scalable() -> vector<4x[4]xi1> { +// CHECK-LABEL: func.func @genbool_2d_trailing_scalable // CHECK: %[[VAL_0:.*]] = arith.constant dense : vector<[4]xi1> // CHECK: %[[VAL_1:.*]] = arith.constant dense : vector<4x[4]xi1> // CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x[4]xi1> to !llvm.array<4 x vector<[4]xi1>> @@ -1851,7 +1851,6 @@ func.func @genbool_2d_scalable() -> 
vector<4x[4]xi1> { // CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<[4]xi1>> // CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<[4]xi1>> to vector<4x[4]xi1> // CHECK: return %[[VAL_5]] : vector<4x[4]xi1> -// CHECK: } // ----- @@ -1861,10 +1860,9 @@ func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> { %0 = vector.constant_mask [4, 2] : vector<[4]x4xi1> return %0 : vector<[4]x4xi1> } -// CHECK-LABEL: func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> { +// CHECK-LABEL: func.func @cannot_genbool_2d_leading_scalable // CHECK: %[[VAL_0:.*]] = vector.constant_mask [4, 2] : vector<[4]x4xi1> // CHECK: return %[[VAL_0]] : vector<[4]x4xi1> -// CHECK: } // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 96c56946cd1cf..18a4c202edfb0 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -1007,7 +1007,7 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>, %C: vector<3x[8]xf32>, %M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> { // CHECK: vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32> - %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } + %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32> return %0 : vector<3x[8]xf32> }