-
Notifications
You must be signed in to change notification settings - Fork 12.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ArmSME] Split the Op definition (nfc) #67985
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir-sme ChangesMove the definitions of LLVM intrinsic Ops for ArmSME into a dedicated This change will allow us to refactor the ArmSME dialect documentation. The documentation will be updated in a forthcoming patch. Only Patch is 21.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67985.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index f947fc8fe1631b8..dcb18be4f05ead4 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -31,4 +31,7 @@
#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSME.h.inc"
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.h.inc"
+
#endif // MLIR_DIALECT_ARMSME_IR_ARMSME_H
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 66a432ea1b171e0..4d89f62e28e0ac4 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -14,33 +14,13 @@
#ifndef ARMSME_OPS
#define ARMSME_OPS
+include "ArmSMEBase.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
-//===----------------------------------------------------------------------===//
-// ArmSME dialect definition
-//===----------------------------------------------------------------------===//
-
-def ArmSME_Dialect : Dialect {
- let name = "arm_sme";
- let cppNamespace = "::mlir::arm_sme";
- let summary = "Basic dialect to target Arm SME architectures";
- let description = [{
- This dialect contains the definitions necessary to target Arm SME
- scalable matrix operations.
-
- Sources:
- https://developer.arm.com/documentation/ddi0616
- https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
- }];
- let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect",
- "memref::MemRefDialect"];
- let useDefaultAttributePrinterParser = 1;
-}
-
//===----------------------------------------------------------------------===//
// ArmSME type definitions
//===----------------------------------------------------------------------===//
@@ -65,12 +45,6 @@ def nxnxv2f64 : SMETileType<F64, [2, 2 ], "vector<[2]x[2]xf64>">;
def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
-def SVEVector : ScalableVectorOfRankAndLengthAndType<
- [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
-
-def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
- [1], [16, 8, 4, 2, 1], [I1]>;
-
// A type constraint that verifies the bitwidth of the scalar integer returned
// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
def TileElementWidthMatchesTileID : TypesMatchWith<
@@ -538,123 +512,4 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
}];
}
-//===----------------------------------------------------------------------===//
-// ArmSME Intrinsic op definitions
-//===----------------------------------------------------------------------===//
-
-def MOPPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2], [I1]>;
-def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
- [I8, I16, BF16, F16, F32, F64]>;
-def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
-
-class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
- list<Trait> traits = [], int numResults = 0,
- list<int> overloadedResults = []>
- : LLVM_IntrOpBase<
- /*Dialect dialect=*/ArmSME_Dialect,
- /*string opName=*/"intr." # mnemonic,
- /*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
- /*list<int> overloadedResults=*/overloadedResults,
- /*list<int> overloadedOperands=*/overloadedOperands,
- /*list<Trait> traits=*/traits,
- /*int numResults=*/numResults>;
-
-// Zero
-def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
- Arguments<(ins Arg<I32, "Tile mask">)>;
-
-// MOP's
-class ArmSME_IntrMopOverloadedOp<string mnemonic>
- : ArmSME_IntrOp<mnemonic, [4]>,
- Arguments<(ins Arg<I32, "Virtual tile ID">,
- Arg<MOPPredicate, "LHS predicate">,
- Arg<MOPPredicate, "RHS predicate">,
- Arg<MOPVector, "LHS vector operand">,
- Arg<MOPVector, "RHS vector operand">)>;
-
-def LLVM_aarch64_sme_mopa : ArmSME_IntrMopOverloadedOp<"mopa">;
-def LLVM_aarch64_sme_mops : ArmSME_IntrMopOverloadedOp<"mops">;
-def LLVM_aarch64_sme_mopa_wide : ArmSME_IntrMopOverloadedOp<"mopa.wide">;
-def LLVM_aarch64_sme_mops_wide : ArmSME_IntrMopOverloadedOp<"mops.wide">;
-def LLVM_aarch64_sme_smopa_wide : ArmSME_IntrMopOverloadedOp<"smopa.wide">;
-def LLVM_aarch64_sme_smops_wide : ArmSME_IntrMopOverloadedOp<"smops.wide">;
-def LLVM_aarch64_sme_umopa_wide : ArmSME_IntrMopOverloadedOp<"umopa.wide">;
-def LLVM_aarch64_sme_umops_wide : ArmSME_IntrMopOverloadedOp<"umops.wide">;
-def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
-def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
-def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
-def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
-
-// Loads
-class ArmSME_IntrLoadOp<string mnemonic>
- : ArmSME_IntrOp<mnemonic>,
- Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
- Arg<LLVM_AnyPointer, "Load address">,
- Arg<I32, "Virtual tile ID">,
- Arg<I32, "Tile slice">)>;
-
-def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
-def LLVM_aarch64_sme_ld1h_horiz : ArmSME_IntrLoadOp<"ld1h.horiz">;
-def LLVM_aarch64_sme_ld1w_horiz : ArmSME_IntrLoadOp<"ld1w.horiz">;
-def LLVM_aarch64_sme_ld1d_horiz : ArmSME_IntrLoadOp<"ld1d.horiz">;
-def LLVM_aarch64_sme_ld1q_horiz : ArmSME_IntrLoadOp<"ld1q.horiz">;
-def LLVM_aarch64_sme_ld1b_vert : ArmSME_IntrLoadOp<"ld1b.vert">;
-def LLVM_aarch64_sme_ld1h_vert : ArmSME_IntrLoadOp<"ld1h.vert">;
-def LLVM_aarch64_sme_ld1w_vert : ArmSME_IntrLoadOp<"ld1w.vert">;
-def LLVM_aarch64_sme_ld1d_vert : ArmSME_IntrLoadOp<"ld1d.vert">;
-def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
-
-// Stores
-class ArmSME_IntrStoreOp<string mnemonic>
- : ArmSME_IntrOp<mnemonic>,
- Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
- Arg<LLVM_AnyPointer, "Store address", [MemWrite]>,
- Arg<I32, "Virtual tile ID">,
- Arg<I32, "Tile slice">)>;
-
-def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
-def LLVM_aarch64_sme_st1h_horiz : ArmSME_IntrStoreOp<"st1h.horiz">;
-def LLVM_aarch64_sme_st1w_horiz : ArmSME_IntrStoreOp<"st1w.horiz">;
-def LLVM_aarch64_sme_st1d_horiz : ArmSME_IntrStoreOp<"st1d.horiz">;
-def LLVM_aarch64_sme_st1q_horiz : ArmSME_IntrStoreOp<"st1q.horiz">;
-def LLVM_aarch64_sme_st1b_vert : ArmSME_IntrStoreOp<"st1b.vert">;
-def LLVM_aarch64_sme_st1h_vert : ArmSME_IntrStoreOp<"st1h.vert">;
-def LLVM_aarch64_sme_st1w_vert : ArmSME_IntrStoreOp<"st1w.vert">;
-def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">;
-def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
-
-def LLVM_aarch64_sme_str
- : ArmSME_IntrOp<"str">,
- Arguments<(ins Arg<I32, "Index">,
- Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
-
-// Vector to tile slice
-class LLVM_aarch64_sme_write<string direction>
- : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
- [AllShapesMatch<["pg", "vector"]>]>,
- Arguments<(ins Arg<I32, "Virtual tile ID">,
- Arg<I32, "Tile slice">,
- Arg<SVEPredicate, "Vector predicate">:$pg,
- Arg<SVEVector, "Vector operand">:$vector)>;
-
-// Tile slice to vector
-class LLVM_aarch64_sme_read<string direction>
- : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
- [AllShapesMatch<["vector", "pg", "res"]>,
- AllElementTypesMatch<["vector", "res"]>],
- /*numResults=*/1, /*overloadedResults=*/[0]>,
- Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
- Arg<SVEPredicate, "Vector predicate">:$pg,
- Arg<I32, "Virtual tile ID">,
- Arg<I32, "Tile slice">)>;
-
-def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
-def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
-
-def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
-def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
-
-def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
-def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
-
#endif // ARMSME_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEBase.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEBase.td
new file mode 100644
index 000000000000000..36f3ae44ff72a2e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEBase.td
@@ -0,0 +1,52 @@
+//===-- ArmSMESMpBase.td - ArmSME dialect definitions -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definition of the ArmSME dialect as well as some
+// shared definitions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ARMSME_OP_BASE
+#define ARMSME_OP_BASE
+
+include "mlir/IR/DialectBase.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// ArmSME Dialect
+//===----------------------------------------------------------------------===//
+
+def ArmSME_Dialect : Dialect {
+ let name = "arm_sme";
+ let cppNamespace = "::mlir::arm_sme";
+ let summary = "Basic dialect to target Arm SME architectures";
+ let description = [{
+ This dialect contains the definitions necessary to target Arm SME
+ scalable matrix operations.
+
+ Sources:
+ https://developer.arm.com/documentation/ddi0616
+ https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
+ }];
+ let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect",
+ "memref::MemRefDialect"];
+ let useDefaultAttributePrinterParser = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// ArmSME type definitions
+//===----------------------------------------------------------------------===//
+
+def SVEVector : ScalableVectorOfRankAndLengthAndType<
+ [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
+
+def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
+ [1], [16, 8, 4, 2, 1], [I1]>;
+
+
+#endif // ARMSME_OP_BASE
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
new file mode 100644
index 000000000000000..e3a051a171400eb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -0,0 +1,137 @@
+//===-- ArmSMEIntrinsicsOps.td -----------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitons of the intrinsic Ops for the ArmSME dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ARMSME_INTRINSICS_OPS
+#define ARMSME_INTRINSICS_OPS
+
+include "ArmSMEBase.td"
+
+//===----------------------------------------------------------------------===//
+// ArmSME Intrinsic op definitions
+//===----------------------------------------------------------------------===//
+
+def MOPPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2], [I1]>;
+def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
+ [I8, I16, BF16, F16, F32, F64]>;
+def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
+
+class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
+ list<Trait> traits = [], int numResults = 0,
+ list<int> overloadedResults = []>
+ : LLVM_IntrOpBase<
+ /*Dialect dialect=*/ArmSME_Dialect,
+ /*string opName=*/"intr." # mnemonic,
+ /*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
+ /*list<int> overloadedResults=*/overloadedResults,
+ /*list<int> overloadedOperands=*/overloadedOperands,
+ /*list<Trait> traits=*/traits,
+ /*int numResults=*/numResults>;
+
+// Zero
+def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
+ Arguments<(ins Arg<I32, "Tile mask">)>;
+
+// MOP's
+class ArmSME_IntrMopOverloadedOp<string mnemonic>
+ : ArmSME_IntrOp<mnemonic, [4]>,
+ Arguments<(ins Arg<I32, "Virtual tile ID">,
+ Arg<MOPPredicate, "LHS predicate">,
+ Arg<MOPPredicate, "RHS predicate">,
+ Arg<MOPVector, "LHS vector operand">,
+ Arg<MOPVector, "RHS vector operand">)>;
+
+def LLVM_aarch64_sme_mopa : ArmSME_IntrMopOverloadedOp<"mopa">;
+def LLVM_aarch64_sme_mops : ArmSME_IntrMopOverloadedOp<"mops">;
+def LLVM_aarch64_sme_mopa_wide : ArmSME_IntrMopOverloadedOp<"mopa.wide">;
+def LLVM_aarch64_sme_mops_wide : ArmSME_IntrMopOverloadedOp<"mops.wide">;
+def LLVM_aarch64_sme_smopa_wide : ArmSME_IntrMopOverloadedOp<"smopa.wide">;
+def LLVM_aarch64_sme_smops_wide : ArmSME_IntrMopOverloadedOp<"smops.wide">;
+def LLVM_aarch64_sme_umopa_wide : ArmSME_IntrMopOverloadedOp<"umopa.wide">;
+def LLVM_aarch64_sme_umops_wide : ArmSME_IntrMopOverloadedOp<"umops.wide">;
+def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
+def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
+def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
+def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
+
+// Loads
+class ArmSME_IntrLoadOp<string mnemonic>
+ : ArmSME_IntrOp<mnemonic>,
+ Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
+ Arg<LLVM_AnyPointer, "Load address">,
+ Arg<I32, "Virtual tile ID">,
+ Arg<I32, "Tile slice">)>;
+
+def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
+def LLVM_aarch64_sme_ld1h_horiz : ArmSME_IntrLoadOp<"ld1h.horiz">;
+def LLVM_aarch64_sme_ld1w_horiz : ArmSME_IntrLoadOp<"ld1w.horiz">;
+def LLVM_aarch64_sme_ld1d_horiz : ArmSME_IntrLoadOp<"ld1d.horiz">;
+def LLVM_aarch64_sme_ld1q_horiz : ArmSME_IntrLoadOp<"ld1q.horiz">;
+def LLVM_aarch64_sme_ld1b_vert : ArmSME_IntrLoadOp<"ld1b.vert">;
+def LLVM_aarch64_sme_ld1h_vert : ArmSME_IntrLoadOp<"ld1h.vert">;
+def LLVM_aarch64_sme_ld1w_vert : ArmSME_IntrLoadOp<"ld1w.vert">;
+def LLVM_aarch64_sme_ld1d_vert : ArmSME_IntrLoadOp<"ld1d.vert">;
+def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
+
+// Stores
+class ArmSME_IntrStoreOp<string mnemonic>
+ : ArmSME_IntrOp<mnemonic>,
+ Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
+ Arg<LLVM_AnyPointer, "Store address", [MemWrite]>,
+ Arg<I32, "Virtual tile ID">,
+ Arg<I32, "Tile slice">)>;
+
+def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
+def LLVM_aarch64_sme_st1h_horiz : ArmSME_IntrStoreOp<"st1h.horiz">;
+def LLVM_aarch64_sme_st1w_horiz : ArmSME_IntrStoreOp<"st1w.horiz">;
+def LLVM_aarch64_sme_st1d_horiz : ArmSME_IntrStoreOp<"st1d.horiz">;
+def LLVM_aarch64_sme_st1q_horiz : ArmSME_IntrStoreOp<"st1q.horiz">;
+def LLVM_aarch64_sme_st1b_vert : ArmSME_IntrStoreOp<"st1b.vert">;
+def LLVM_aarch64_sme_st1h_vert : ArmSME_IntrStoreOp<"st1h.vert">;
+def LLVM_aarch64_sme_st1w_vert : ArmSME_IntrStoreOp<"st1w.vert">;
+def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">;
+def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
+
+def LLVM_aarch64_sme_str
+ : ArmSME_IntrOp<"str">,
+ Arguments<(ins Arg<I32, "Index">,
+ Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
+
+// Vector to tile slice
+class LLVM_aarch64_sme_write<string direction>
+ : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
+ [AllShapesMatch<["pg", "vector"]>]>,
+ Arguments<(ins Arg<I32, "Virtual tile ID">,
+ Arg<I32, "Tile slice">,
+ Arg<SVEPredicate, "Vector predicate">:$pg,
+ Arg<SVEVector, "Vector operand">:$vector)>;
+
+// Tile slice to vector
+class LLVM_aarch64_sme_read<string direction>
+ : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+ [AllShapesMatch<["vector", "pg", "res"]>,
+ AllElementTypesMatch<["vector", "res"]>],
+ /*numResults=*/1, /*overloadedResults=*/[0]>,
+ Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
+ Arg<SVEPredicate, "Vector predicate">:$pg,
+ Arg<I32, "Virtual tile ID">,
+ Arg<I32, "Tile slice">)>;
+
+def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
+def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
+
+def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
+def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
+
+def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
+def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
+
+#endif // ARMSME_INTRINSICS_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index 617809e482b2caa..62b61aa90246a01 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -10,3 +10,10 @@ mlir_tablegen(ArmSMEEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme)
mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme)
add_public_tablegen_target(MLIRArmSMEAttrDefsIncGen)
+
+# Generate declarations and definitions of ArmSME intrinsics
+set(LLVM_TARGET_DEFINITIONS ArmSMEIntrinsicOps.td)
+mlir_tablegen(ArmSMEIntrinsicOps.h.inc -gen-op-decls)
+mlir_tablegen(ArmSMEIntrinsicOps.cpp.inc -gen-op-defs)
+mlir_tablegen(ArmSMEIntrinsicConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(ArmSMEIntrinsicOpsIncGen)
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 101cb750f4a6f30..e7fce6a7fe6f1f5 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -31,6 +31,9 @@ using namespace mlir::arm_sme;
#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.cpp.inc"
+
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMETypes.cpp.inc"
@@ -46,6 +49,9 @@ void ArmSMEDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
+ ,
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.cpp.inc"
>();
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index e75e958e18a2cfd..51cbd5423ecf49b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -187,10 +187,10 @@ struct LoadTileSliceToArmSMELowering
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
auto tileI32 = castTileIDToI32(tile, loc, rewriter);
- arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
+ TileSliceLayout layout = loadTileSliceOp.getLayout();
// Create 'arm_sme.intr.ld1*.(horiz|vert)' intrin...
[truncated]
|
CI does not seem happy with this change 😅 |
Looks like a missing CMake dependency - I've sent an update that should fix it 🤞🏻 |
@@ -187,10 +187,10 @@ struct LoadTileSliceToArmSMELowering | |||
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one); | |||
|
|||
auto tileI32 = castTileIDToI32(tile, loc, rewriter); | |||
arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why has the namespace been removed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not needed (see using namespace mlir::arm_sme;
at the top), but this is an unrelated change and I will revert it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering if we should rename this to ArmSMEOps.td
? And the new .td file with the dialect could be ArmSME.td
or ArmSMEDialect.td
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That was my original plan, but was intending to do it in a follow-up. But may as well do it here. That will require a few CMake changes, but the noise level is pretty low.
✅ With the latest revision this PR passed the C/C++ code formatter. |
6c6704b
to
5a25f65
Compare
Move the definitions of LLVM intrinsic Ops for ArmSME into a dedicated file. To facilitate this, the dialect definition together with various shared definitions are moved to ArmSMEBase.td. This change will allow us to refactor the ArmSME dialect documentation. In particular, we will be able to categorise the Ops into "regular" and "intrinsic" ops. Also, it will be easier to add some custom documentation as opposed to relying on auto-generated docs that simply list the available Ops. The documentation will be updated in a forthcoming patch. Only non-functional changes.
Add missing CMake dependency
* Revert the removal of `arm_sme` namespace qualifier * Rename `ArmSME.td` as `ArmOpsSME.td` (file defining regular Ops) * Rename `ArmSMEBase.td` as `ArmSME.td` (main dialect file) * Fix typos
Fix clang-format issues
@@ -1,12 +1,24 @@ | |||
add_mlir_dialect(ArmSME arm_sme ArmSME) | |||
add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that with these changes the docs will be empty as ArmSME.td
does not include the op definitions. This was the same problem the vector dialect had a little while ago.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#68110 ;-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No worries then 😄
1f3a53c
to
2f8c907
Compare
Update comments
mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme) | ||
mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme) | ||
add_public_tablegen_target(MLIRArmSMEAttrDefsIncGen) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess this is part of the MLIRArmSMEOpsIncGen
target now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed. Everything that's defined in ArmSMEOp.td should lead to one public target (unless there's a good reason not to).
Address comments from Cullen
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM cheers!
Move the definitions of LLVM intrinsic Ops for ArmSME into a dedicated file. To facilitate this, the dialect definition together with various shared definitions are moved to ArmSME.td. This change will allow us to refactor the ArmSME dialect documentation. In particular, we will be able to categorise the Ops into "regular" and "intrinsic" ops. Also, it will be easier to add some custom documentation as opposed to relying on auto-generated docs that simply list the available Ops. The documentation will be updated in a forthcoming patch. Only non-functional changes.
Thanks for the reviews - I merged this one manually: 8d6d4f8 (experimenting with different workflows). |
This patch introduces a hand-written markdown file that documents the ArmSME dialect. ATM, it simply includes the auto-generated documentation for the custom and for the LLVM intrinsic ops. Doing this by hand allows us to clarify the document by splitting it into meaningful categories. Moving forward, this change will allow us to improve/expand the documentation for the ArmSME dialect (and, in general, about supporting SME in MLIR). Depends on llvm#67985
Move the definitions of LLVM intrinsic Ops for ArmSME into a dedicated
file. To facilitate this, the dialect definition together with various
shared definitions are moved to ArmSMEBase.td.
This change will allow us to refactor the ArmSME dialect documentation.
In particular, we will be able to categorise the Ops into "regular" and
"intrinsic" ops. Also, it will be easier to add some custom
documentation as opposed to relying on auto-generated docs that simply
list the available Ops.
The documentation will be updated in a forthcoming patch. Only
non-functional changes.