From e217711238356ff53520a01b32630c8df59da8ec Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Sun, 15 Oct 2023 08:09:57 +0000 Subject: [PATCH 1/2] [mlir][ArmSME] Add optional padding and mask operands to tile_load Padding and mask are optional, but if one is specified both must be specified. This is consistent with vector.transfer_read. --- .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 50 +++++++++++++++++-- mlir/test/Dialect/ArmSME/invalid.mlir | 44 ++++++++++++++++ mlir/test/Dialect/ArmSME/roundtrip.mlir | 10 ++++ 3 files changed, 101 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index 9b9dbff10ea2d..31287c2c259db 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -231,7 +231,24 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> { let assemblyFormat = "attr-dict `:` type($res)"; } -def TileLoadOp : ArmSME_Op<"tile_load"> { +def TileLoadOp : ArmSME_Op<"tile_load", [ + AttrSizedOperandSegments, + TypesMatchWith< + "padding type matches element type of result (if present)", + "result", "padding", + "::llvm::cast($_self).getElementType()", + "!getPadding() || std::equal_to<>()" + >, + TypesMatchWith< + "mask has i1 element type and same shape as result (if present)", + "result", "mask", + "VectorType(" + "VectorType::Builder(" + "::llvm::cast($_self)" + ").setElementType(IntegerType::get($_self.getContext(), 1)))", + "!getMask() || std::equal_to<>()" + > +]> { let summary = "Tile load operation"; let description = [{ Loads a 2D SME "virtual tile" from memory defined by a base and indices, @@ -242,6 +259,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> { dimensions, since the operation is scalable, and the element type must be a scalar that matches the element type of the result. + An optional SSA value `padding` of the same elemental type as the MemRef is + provided to specify a fallback value in the case of masking. + + An optional SSA value `mask` may be specified to mask out elements read + from the MemRef. The `mask` type is an `i1` vector with a shape that + matches how elements are read from the MemRef. Elements whose corresponding + mask element is `0` are masked out and replaced with `padding`. + + If either `padding` or `mask` are specified, both must be specified. + Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B). ```mlir %tile = arm_sme.tile_load %base[%c0, %c0] : memref, vector<[16]x[16]xi8> @@ -256,10 +283,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> { ```mlir %tile = arm_sme.tile_load %base[%c0, %c0] layout : memref, vector<[1]x[1]xi128> ``` + + Example 4: Masked load of int 32-bit element ZA tile with horizontal layout (default) from memory. + ```mlir + %tile = arm_sme.tile_load %base[%c0, %c0], %pad, %mask : memref, vector<[4]x[4]xf32> + ``` }]; let arguments = (ins Arg:$base, Variadic:$indices, + Optional:$padding, Optional:$mask, ArmSME_TileSliceLayoutAttr:$layout ); let results = (outs SMETile:$result); @@ -273,9 +306,20 @@ def TileLoadOp : ArmSME_Op<"tile_load"> { } }]; + let builders = [ + OpBuilder<(ins "VectorType":$resultType, "Value":$base, + "ValueRange":$indices, "TileSliceLayout":$layout), [{ + build($_builder, $_state, resultType, base, indices, {}, {}, layout); + }]>, + OpBuilder<(ins "VectorType":$resultType, "Value":$base, + "ValueRange":$indices), [{ + build($_builder, $_state, resultType, base, indices, {}, {}, {}); + }]>, + ]; + let assemblyFormat = - "$base `[` $indices `]` (`layout` `` $layout^)? attr-dict " - "`:` type($base) `,` type($result)"; + "$base `[` $indices `]` (`,` $padding `,` $mask^)? (`layout` `` $layout^)?" + "attr-dict `:` type($base) `,` type($result)"; } def TileStoreOp : ArmSME_Op<"tile_store"> { diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir index 431009b1b9ede..9229f0415c076 100644 --- a/mlir/test/Dialect/ArmSME/invalid.mlir +++ b/mlir/test/Dialect/ArmSME/invalid.mlir @@ -1,5 +1,9 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics +//===----------------------------------------------------------------------===// +// arm_sme.cast_tile_to_vector +//===----------------------------------------------------------------------===// + // ----- func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> { @@ -48,6 +52,10 @@ func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[1 return %0 : vector<[4]x[16]xi8> } +//===----------------------------------------------------------------------===// +// arm_sme.cast_vector_to_tile +//===----------------------------------------------------------------------===// + // ----- func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 { @@ -64,6 +72,10 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) - return %0 : i8 } +//===----------------------------------------------------------------------===// +// arm_sme.get_tile_id +//===----------------------------------------------------------------------===// + // ----- func.func @arm_sme_get_tile_id__bad_type() -> i1 { @@ -72,6 +84,10 @@ func.func @arm_sme_get_tile_id__bad_type() -> i1 { return %0 : i1 } +//===----------------------------------------------------------------------===// +// arm_sme.move_vector_to_tile_slice +//===----------------------------------------------------------------------===// + // ----- func.func @arm_sme_move_vector_to_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> { @@ -90,6 +106,10 @@ func.func @arm_sme_move_vector_to_tile_slice_f32__bad_vector_type(%vector : vect return %0 : vector<[4]x[4]xf32> } +//===----------------------------------------------------------------------===// +// arm_sme.move_tile_slice_to_vector +//===----------------------------------------------------------------------===// + // ----- func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> { @@ -97,3 +117,27 @@ func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4] %0 = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32> return %0 : vector<[2]xf64> } + +//===----------------------------------------------------------------------===// +// arm_sme.tile_load +//===----------------------------------------------------------------------===// + +// ----- + +func.func @arm_sme_tile_load__bad_padding_type(%src : memref, %pad : f32, %mask : vector<[2]x[2]xi1>) { + %c0 = arith.constant 0 : index + // expected-note@-2 {{prior use here}} + // expected-error@+1 {{use of value '%pad' expects different type than prior uses: 'f64' vs 'f32'}} + %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref, vector<[2]x[2]xf64> + return +} + +// ----- + +func.func @arm_sme_tile_load__bad_mask_type(%src : memref, %pad : f64, %mask : vector<[4]x[4]xi1>) { + %c0 = arith.constant 0 : index + // expected-note@-2 {{prior use here}} + // expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<[2]x[2]xi1>' vs 'vector<[4]x[4]xi1>}} + %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref, vector<[2]x[2]xf64> + return +} diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir index e5ba81eff8360..6866137267dc6 100644 --- a/mlir/test/Dialect/ArmSME/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -438,6 +438,16 @@ func.func @arm_sme_tile_load_ver_f64(%src : memref) { // ----- +/// Padding and mask are optional +func.func @arm_sme_tile_load_hor_pad_f64(%src : memref, %pad : f64, %mask : vector<[2]x[2]xi1>) { + // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref, vector<[2]x[2]xf64> + %c0 = arith.constant 0 : index + %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref, vector<[2]x[2]xf64> + return +} + +// ----- + /// Layout is optional and horizontal is the default, verify it's still parsed. func.func @arm_sme_tile_load_explicit_hor(%src : memref) { // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[16]x[16]xi8> From dbc1c024f9ffd36f292301fc62d72748688297c3 Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Mon, 30 Oct 2023 11:17:46 +0000 Subject: [PATCH 2/2] Changes: * Use OptionalTypesMatchWith * Add constraint to verify both padding and mask are specified, as well as test. --- .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 20 ++++++++++--------- mlir/test/Dialect/ArmSME/invalid.mlir | 9 +++++++++ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index 31287c2c259db..b30d0fdb866bd 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -233,21 +233,23 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> { def TileLoadOp : ArmSME_Op<"tile_load", [ AttrSizedOperandSegments, - TypesMatchWith< - "padding type matches element type of result (if present)", + OptionalTypesMatchWith< + "padding type matches element type of result", "result", "padding", - "::llvm::cast($_self).getElementType()", - "!getPadding() || std::equal_to<>()" + "::llvm::cast($_self).getElementType()" >, - TypesMatchWith< - "mask has i1 element type and same shape as result (if present)", + OptionalTypesMatchWith< + "mask has i1 element type and same shape as result", "result", "mask", "VectorType(" "VectorType::Builder(" "::llvm::cast($_self)" - ").setElementType(IntegerType::get($_self.getContext(), 1)))", - "!getMask() || std::equal_to<>()" - > + ").setElementType(IntegerType::get($_self.getContext(), 1)))" + >, + PredOpTrait< + "both `padding` and `mask` should be provided or neither", + CPred<"bool(getPadding()) == bool(getMask())"> + >, ]> { let summary = "Tile load operation"; let description = [{ diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir index 9229f0415c076..25c62f78d8435 100644 --- a/mlir/test/Dialect/ArmSME/invalid.mlir +++ b/mlir/test/Dialect/ArmSME/invalid.mlir @@ -141,3 +141,12 @@ func.func @arm_sme_tile_load__bad_mask_type(%src : memref, %pad : f64, %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref, vector<[2]x[2]xf64> return } + +// ----- + +func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref, %pad : f64) { + %c0 = arith.constant 0 : index + // expected-error@+1 {{op failed to verify that both `padding` and `mask` should be provided or neither}} + %tile = arm_sme.tile_load %src[%c0, %c0], %pad, : memref, vector<[2]x[2]xf64> + return +}