Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ArmSME] Split the Op definition (nfc) #67985

Closed
wants to merge 6 commits into from

Conversation

banach-space
Copy link
Contributor

Move the definitions of LLVM intrinsic Ops for ArmSME into a dedicated
file. To facilitate this, the dialect definition together with various
shared definitions are moved to ArmSMEBase.td.

This change will allow us to refactor the ArmSME dialect documentation.
In particular, we will be able to categorise the Ops into "regular" and
"intrinsic" ops. Also, it will be easier to add some custom
documentation as opposed to relying on auto-generated docs that simply
list the available Ops.

The documentation will be updated in a forthcoming patch. Only
non-functional changes.

@llvmbot
Copy link
Member

llvmbot commented Oct 2, 2023

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sme

Changes

Move the definitions of LLVM intrinsic Ops for ArmSME into a dedicated
file. To facilitate this, the dialect definition together with various
shared definitions are moved to ArmSMEBase.td.

This change will allow us to refactor the ArmSME dialect documentation.
In particular, we will be able to categorise the Ops into "regular" and
"intrinsic" ops. Also, it will be easier to add some custom
documentation as opposed to relying on auto-generated docs that simply
list the available Ops.

The documentation will be updated in a forthcoming patch. Only
non-functional changes.


Patch is 21.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67985.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h (+3)
  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td (+1-146)
  • (added) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEBase.td (+52)
  • (added) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td (+137)
  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt (+7)
  • (modified) mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp (+6)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (+4-4)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp (+1)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index f947fc8fe1631b8..dcb18be4f05ead4 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -31,4 +31,7 @@
 #define GET_OP_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h.inc"
 
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.h.inc"
+
 #endif // MLIR_DIALECT_ARMSME_IR_ARMSME_H
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 66a432ea1b171e0..4d89f62e28e0ac4 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -14,33 +14,13 @@
 #ifndef ARMSME_OPS
 #define ARMSME_OPS
 
+include "ArmSMEBase.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 
-//===----------------------------------------------------------------------===//
-// ArmSME dialect definition
-//===----------------------------------------------------------------------===//
-
-def ArmSME_Dialect : Dialect {
-  let name = "arm_sme";
-  let cppNamespace = "::mlir::arm_sme";
-  let summary = "Basic dialect to target Arm SME architectures";
-  let description = [{
-    This dialect contains the definitions necessary to target Arm SME
-    scalable matrix operations.
-
-    Sources:
-    https://developer.arm.com/documentation/ddi0616
-    https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
-  }];
-  let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect",
-                           "memref::MemRefDialect"];
-  let useDefaultAttributePrinterParser = 1;
-}
-
 //===----------------------------------------------------------------------===//
 // ArmSME type definitions
 //===----------------------------------------------------------------------===//
@@ -65,12 +45,6 @@ def nxnxv2f64  : SMETileType<F64,  [2,  2 ], "vector<[2]x[2]xf64>">;
 def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
                          nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
 
-def SVEVector : ScalableVectorOfRankAndLengthAndType<
-  [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
-
-def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
-  [1], [16, 8, 4, 2, 1], [I1]>;
-
 // A type constraint that verifies the bitwidth of the scalar integer returned
 // from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
 def TileElementWidthMatchesTileID : TypesMatchWith<
@@ -538,123 +512,4 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// ArmSME Intrinsic op definitions
-//===----------------------------------------------------------------------===//
-
-def MOPPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2], [I1]>;
-def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
-                                              [I8, I16, BF16, F16, F32, F64]>;
-def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
-
-class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
-                    list<Trait> traits = [], int numResults = 0,
-                    list<int> overloadedResults = []>
-    : LLVM_IntrOpBase<
-          /*Dialect dialect=*/ArmSME_Dialect,
-          /*string opName=*/"intr." # mnemonic,
-          /*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
-          /*list<int> overloadedResults=*/overloadedResults,
-          /*list<int> overloadedOperands=*/overloadedOperands,
-          /*list<Trait> traits=*/traits,
-          /*int numResults=*/numResults>;
-
-// Zero
-def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
-                            Arguments<(ins Arg<I32, "Tile mask">)>;
-
-// MOP's
-class ArmSME_IntrMopOverloadedOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic, [4]>,
-      Arguments<(ins Arg<I32, "Virtual tile ID">,
-                 Arg<MOPPredicate, "LHS predicate">,
-                 Arg<MOPPredicate, "RHS predicate">,
-                 Arg<MOPVector, "LHS vector operand">,
-                 Arg<MOPVector, "RHS vector operand">)>;
-
-def LLVM_aarch64_sme_mopa : ArmSME_IntrMopOverloadedOp<"mopa">;
-def LLVM_aarch64_sme_mops : ArmSME_IntrMopOverloadedOp<"mops">;
-def LLVM_aarch64_sme_mopa_wide : ArmSME_IntrMopOverloadedOp<"mopa.wide">;
-def LLVM_aarch64_sme_mops_wide : ArmSME_IntrMopOverloadedOp<"mops.wide">;
-def LLVM_aarch64_sme_smopa_wide : ArmSME_IntrMopOverloadedOp<"smopa.wide">;
-def LLVM_aarch64_sme_smops_wide : ArmSME_IntrMopOverloadedOp<"smops.wide">;
-def LLVM_aarch64_sme_umopa_wide : ArmSME_IntrMopOverloadedOp<"umopa.wide">;
-def LLVM_aarch64_sme_umops_wide : ArmSME_IntrMopOverloadedOp<"umops.wide">;
-def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
-def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
-def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
-def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
-
-// Loads
-class ArmSME_IntrLoadOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic>,
-      Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
-                 Arg<LLVM_AnyPointer, "Load address">,
-                 Arg<I32, "Virtual tile ID">,
-                 Arg<I32, "Tile slice">)>;
-
-def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
-def LLVM_aarch64_sme_ld1h_horiz : ArmSME_IntrLoadOp<"ld1h.horiz">;
-def LLVM_aarch64_sme_ld1w_horiz : ArmSME_IntrLoadOp<"ld1w.horiz">;
-def LLVM_aarch64_sme_ld1d_horiz : ArmSME_IntrLoadOp<"ld1d.horiz">;
-def LLVM_aarch64_sme_ld1q_horiz : ArmSME_IntrLoadOp<"ld1q.horiz">;
-def LLVM_aarch64_sme_ld1b_vert : ArmSME_IntrLoadOp<"ld1b.vert">;
-def LLVM_aarch64_sme_ld1h_vert : ArmSME_IntrLoadOp<"ld1h.vert">;
-def LLVM_aarch64_sme_ld1w_vert : ArmSME_IntrLoadOp<"ld1w.vert">;
-def LLVM_aarch64_sme_ld1d_vert : ArmSME_IntrLoadOp<"ld1d.vert">;
-def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
-
-// Stores
-class ArmSME_IntrStoreOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic>,
-      Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
-                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>,
-                 Arg<I32, "Virtual tile ID">,
-                 Arg<I32, "Tile slice">)>;
-
-def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
-def LLVM_aarch64_sme_st1h_horiz : ArmSME_IntrStoreOp<"st1h.horiz">;
-def LLVM_aarch64_sme_st1w_horiz : ArmSME_IntrStoreOp<"st1w.horiz">;
-def LLVM_aarch64_sme_st1d_horiz : ArmSME_IntrStoreOp<"st1d.horiz">;
-def LLVM_aarch64_sme_st1q_horiz : ArmSME_IntrStoreOp<"st1q.horiz">;
-def LLVM_aarch64_sme_st1b_vert : ArmSME_IntrStoreOp<"st1b.vert">;
-def LLVM_aarch64_sme_st1h_vert : ArmSME_IntrStoreOp<"st1h.vert">;
-def LLVM_aarch64_sme_st1w_vert : ArmSME_IntrStoreOp<"st1w.vert">;
-def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">;
-def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
-
-def LLVM_aarch64_sme_str
-    : ArmSME_IntrOp<"str">,
-      Arguments<(ins Arg<I32, "Index">,
-                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
-
-// Vector to tile slice
-class LLVM_aarch64_sme_write<string direction>
-    : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
-                    [AllShapesMatch<["pg", "vector"]>]>,
-      Arguments<(ins Arg<I32, "Virtual tile ID">,
-                     Arg<I32, "Tile slice">,
-                     Arg<SVEPredicate, "Vector predicate">:$pg,
-                     Arg<SVEVector, "Vector operand">:$vector)>;
-
-// Tile slice to vector
-class LLVM_aarch64_sme_read<string direction>
-    : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
-                    [AllShapesMatch<["vector", "pg", "res"]>,
-                     AllElementTypesMatch<["vector", "res"]>],
-                    /*numResults=*/1, /*overloadedResults=*/[0]>,
-      Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
-                     Arg<SVEPredicate, "Vector predicate">:$pg,
-                     Arg<I32, "Virtual tile ID">,
-                     Arg<I32, "Tile slice">)>;
-
-def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
-def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
-
-def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
-def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
-
-def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
-def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
-
 #endif // ARMSME_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEBase.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEBase.td
new file mode 100644
index 000000000000000..36f3ae44ff72a2e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEBase.td
@@ -0,0 +1,52 @@
+//===-- ArmSMESMpBase.td - ArmSME dialect definitions -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definition of the ArmSME dialect as well as some
+// shared definitions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ARMSME_OP_BASE
+#define ARMSME_OP_BASE
+
+include "mlir/IR/DialectBase.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// ArmSME Dialect
+//===----------------------------------------------------------------------===//
+
+def ArmSME_Dialect : Dialect {
+  let name = "arm_sme";
+  let cppNamespace = "::mlir::arm_sme";
+  let summary = "Basic dialect to target Arm SME architectures";
+  let description = [{
+    This dialect contains the definitions necessary to target Arm SME
+    scalable matrix operations.
+
+    Sources:
+    https://developer.arm.com/documentation/ddi0616
+    https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
+  }];
+  let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect",
+                           "memref::MemRefDialect"];
+  let useDefaultAttributePrinterParser = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// ArmSME type definitions
+//===----------------------------------------------------------------------===//
+
+def SVEVector : ScalableVectorOfRankAndLengthAndType<
+  [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
+
+def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
+  [1], [16, 8, 4, 2, 1], [I1]>;
+
+
+#endif // ARMSME_OP_BASE
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
new file mode 100644
index 000000000000000..e3a051a171400eb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -0,0 +1,137 @@
+//===-- ArmSMEIntrinsicsOps.td -----------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitons of the intrinsic Ops for the ArmSME dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ARMSME_INTRINSICS_OPS
+#define ARMSME_INTRINSICS_OPS
+
+include "ArmSMEBase.td"
+
+//===----------------------------------------------------------------------===//
+// ArmSME Intrinsic op definitions
+//===----------------------------------------------------------------------===//
+
+def MOPPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2], [I1]>;
+def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
+                                              [I8, I16, BF16, F16, F32, F64]>;
+def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
+
+class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
+                    list<Trait> traits = [], int numResults = 0,
+                    list<int> overloadedResults = []>
+    : LLVM_IntrOpBase<
+          /*Dialect dialect=*/ArmSME_Dialect,
+          /*string opName=*/"intr." # mnemonic,
+          /*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
+          /*list<int> overloadedResults=*/overloadedResults,
+          /*list<int> overloadedOperands=*/overloadedOperands,
+          /*list<Trait> traits=*/traits,
+          /*int numResults=*/numResults>;
+
+// Zero
+def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
+                            Arguments<(ins Arg<I32, "Tile mask">)>;
+
+// MOP's
+class ArmSME_IntrMopOverloadedOp<string mnemonic>
+    : ArmSME_IntrOp<mnemonic, [4]>,
+      Arguments<(ins Arg<I32, "Virtual tile ID">,
+                 Arg<MOPPredicate, "LHS predicate">,
+                 Arg<MOPPredicate, "RHS predicate">,
+                 Arg<MOPVector, "LHS vector operand">,
+                 Arg<MOPVector, "RHS vector operand">)>;
+
+def LLVM_aarch64_sme_mopa : ArmSME_IntrMopOverloadedOp<"mopa">;
+def LLVM_aarch64_sme_mops : ArmSME_IntrMopOverloadedOp<"mops">;
+def LLVM_aarch64_sme_mopa_wide : ArmSME_IntrMopOverloadedOp<"mopa.wide">;
+def LLVM_aarch64_sme_mops_wide : ArmSME_IntrMopOverloadedOp<"mops.wide">;
+def LLVM_aarch64_sme_smopa_wide : ArmSME_IntrMopOverloadedOp<"smopa.wide">;
+def LLVM_aarch64_sme_smops_wide : ArmSME_IntrMopOverloadedOp<"smops.wide">;
+def LLVM_aarch64_sme_umopa_wide : ArmSME_IntrMopOverloadedOp<"umopa.wide">;
+def LLVM_aarch64_sme_umops_wide : ArmSME_IntrMopOverloadedOp<"umops.wide">;
+def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
+def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
+def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
+def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
+
+// Loads
+class ArmSME_IntrLoadOp<string mnemonic>
+    : ArmSME_IntrOp<mnemonic>,
+      Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
+                 Arg<LLVM_AnyPointer, "Load address">,
+                 Arg<I32, "Virtual tile ID">,
+                 Arg<I32, "Tile slice">)>;
+
+def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
+def LLVM_aarch64_sme_ld1h_horiz : ArmSME_IntrLoadOp<"ld1h.horiz">;
+def LLVM_aarch64_sme_ld1w_horiz : ArmSME_IntrLoadOp<"ld1w.horiz">;
+def LLVM_aarch64_sme_ld1d_horiz : ArmSME_IntrLoadOp<"ld1d.horiz">;
+def LLVM_aarch64_sme_ld1q_horiz : ArmSME_IntrLoadOp<"ld1q.horiz">;
+def LLVM_aarch64_sme_ld1b_vert : ArmSME_IntrLoadOp<"ld1b.vert">;
+def LLVM_aarch64_sme_ld1h_vert : ArmSME_IntrLoadOp<"ld1h.vert">;
+def LLVM_aarch64_sme_ld1w_vert : ArmSME_IntrLoadOp<"ld1w.vert">;
+def LLVM_aarch64_sme_ld1d_vert : ArmSME_IntrLoadOp<"ld1d.vert">;
+def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
+
+// Stores
+class ArmSME_IntrStoreOp<string mnemonic>
+    : ArmSME_IntrOp<mnemonic>,
+      Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
+                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>,
+                 Arg<I32, "Virtual tile ID">,
+                 Arg<I32, "Tile slice">)>;
+
+def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
+def LLVM_aarch64_sme_st1h_horiz : ArmSME_IntrStoreOp<"st1h.horiz">;
+def LLVM_aarch64_sme_st1w_horiz : ArmSME_IntrStoreOp<"st1w.horiz">;
+def LLVM_aarch64_sme_st1d_horiz : ArmSME_IntrStoreOp<"st1d.horiz">;
+def LLVM_aarch64_sme_st1q_horiz : ArmSME_IntrStoreOp<"st1q.horiz">;
+def LLVM_aarch64_sme_st1b_vert : ArmSME_IntrStoreOp<"st1b.vert">;
+def LLVM_aarch64_sme_st1h_vert : ArmSME_IntrStoreOp<"st1h.vert">;
+def LLVM_aarch64_sme_st1w_vert : ArmSME_IntrStoreOp<"st1w.vert">;
+def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">;
+def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
+
+def LLVM_aarch64_sme_str
+    : ArmSME_IntrOp<"str">,
+      Arguments<(ins Arg<I32, "Index">,
+                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
+
+// Vector to tile slice
+class LLVM_aarch64_sme_write<string direction>
+    : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
+                    [AllShapesMatch<["pg", "vector"]>]>,
+      Arguments<(ins Arg<I32, "Virtual tile ID">,
+                     Arg<I32, "Tile slice">,
+                     Arg<SVEPredicate, "Vector predicate">:$pg,
+                     Arg<SVEVector, "Vector operand">:$vector)>;
+
+// Tile slice to vector
+class LLVM_aarch64_sme_read<string direction>
+    : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+                    [AllShapesMatch<["vector", "pg", "res"]>,
+                     AllElementTypesMatch<["vector", "res"]>],
+                    /*numResults=*/1, /*overloadedResults=*/[0]>,
+      Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
+                     Arg<SVEPredicate, "Vector predicate">:$pg,
+                     Arg<I32, "Virtual tile ID">,
+                     Arg<I32, "Tile slice">)>;
+
+def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
+def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
+
+def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
+def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
+
+def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
+def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
+
+#endif // ARMSME_INTRINSICS_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index 617809e482b2caa..62b61aa90246a01 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -10,3 +10,10 @@ mlir_tablegen(ArmSMEEnums.cpp.inc -gen-enum-defs)
 mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme)
 mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme)
 add_public_tablegen_target(MLIRArmSMEAttrDefsIncGen)
+
+# Generate declarations and definitions of ArmSME intrinsics
+set(LLVM_TARGET_DEFINITIONS ArmSMEIntrinsicOps.td)
+mlir_tablegen(ArmSMEIntrinsicOps.h.inc -gen-op-decls)
+mlir_tablegen(ArmSMEIntrinsicOps.cpp.inc -gen-op-defs)
+mlir_tablegen(ArmSMEIntrinsicConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(ArmSMEIntrinsicOpsIncGen)
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 101cb750f4a6f30..e7fce6a7fe6f1f5 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -31,6 +31,9 @@ using namespace mlir::arm_sme;
 #define GET_OP_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
 
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.cpp.inc"
+
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMETypes.cpp.inc"
 
@@ -46,6 +49,9 @@ void ArmSMEDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
+      ,
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.cpp.inc"
       >();
 }
 
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index e75e958e18a2cfd..51cbd5423ecf49b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -187,10 +187,10 @@ struct LoadTileSliceToArmSMELowering
     auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
 
     auto tileI32 = castTileIDToI32(tile, loc, rewriter);
-    arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
+    TileSliceLayout layout = loadTileSliceOp.getLayout();
 
     // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrin...
[truncated]

@MacDue
Copy link
Member

MacDue commented Oct 2, 2023

CI does not seem happy with this change 😅

@banach-space
Copy link
Contributor Author

CI does not seem happy with this change 😅

Looks like a missing CMake dependency - I've sent an update that should fix it 🤞🏻

@@ -187,10 +187,10 @@ struct LoadTileSliceToArmSMELowering
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);

auto tileI32 = castTileIDToI32(tile, loc, rewriter);
arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why has the namespace been removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not needed (see using namespace mlir::arm_sme; at the top), but this is an unrelated change and I will revert it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we should rename this to ArmSMEOps.td? And the new .td file with the dialect could be ArmSME.td or ArmSMEDialect.td?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was my original plan, but was intending to do it in a follow-up. But may as well do it here. That will require a few CMake changes, but the noise level is pretty low.

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEBase.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td Outdated Show resolved Hide resolved
@github-actions
Copy link

github-actions bot commented Oct 3, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

Move the definitions of LLVM intrinsic Ops for ArmSME into a dedicated
file. To facilitate this, the dialect definition together with various
shared definitions are moved to ArmSMEBase.td.

This change will allow us to refactor the ArmSME dialect documentation.
In particular, we will be able to categorise the Ops into "regular"  and
"intrinsic" ops. Also, it will be easier to add some custom
documentation as opposed to relying on auto-generated docs that simply
list the available Ops.

The documentation will be updated in a forthcoming patch. Only
non-functional changes.
* Revert the removal of `arm_sme` namespace qualifier
* Rename `ArmSME.td` as `ArmOpsSME.td` (file defining regular Ops)
* Rename `ArmSMEBase.td` as `ArmSME.td` (main dialect file)
* Fix typos
@@ -1,12 +1,24 @@
add_mlir_dialect(ArmSME arm_sme ArmSME)
add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that with these changes the docs will be empty as ArmSME.td does not include the op definitions. This was the same problem the vector dialect had a little while ago.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#68110 ;-)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No worries then 😄

@banach-space banach-space force-pushed the andrzej/split_op_defs branch from 1f3a53c to 2f8c907 Compare October 3, 2023 14:32
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td Outdated Show resolved Hide resolved
mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme)
mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme)
add_public_tablegen_target(MLIRArmSMEAttrDefsIncGen)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is part of the MLIRArmSMEOpsIncGen target now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed. Everything that's defined in ArmSMEOp.td should lead to one public target (unless there's a good reason not to).

mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt Outdated Show resolved Hide resolved
Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM cheers!

banach-space added a commit that referenced this pull request Oct 4, 2023
Move the definitions of LLVM intrinsic Ops for ArmSME into a dedicated
file. To facilitate this, the dialect definition together with various
shared definitions are moved to ArmSME.td.

This change will allow us to refactor the ArmSME dialect documentation.
In particular, we will be able to categorise the Ops into "regular"  and
"intrinsic" ops. Also, it will be easier to add some custom
documentation as opposed to relying on auto-generated docs that simply
list the available Ops.

The documentation will be updated in a forthcoming patch. Only
non-functional changes.
@banach-space
Copy link
Contributor Author

Thanks for the reviews - I merged this one manually: 8d6d4f8 (experimenting with different workflows).

banach-space added a commit to banach-space/llvm-project that referenced this pull request Oct 4, 2023
This patch introduces a hand-written markdown file that documents the
ArmSME dialect. ATM, it simply includes the auto-generated documentation
for the custom and for the LLVM intrinsic ops. Doing this by hand allows
us to clarify the document by splitting it into meaningful categories.

Moving forward, this change will allow us to improve/expand the
documentation for the ArmSME dialect (and, in general, about supporting
SME in MLIR).

Depends on llvm#67985
@MacDue
Copy link
Member

MacDue commented Oct 4, 2023

Thanks for the reviews - I merged this one manually: 8d6d4f8 (experimenting with different workflows).

P.s. If you had put "Resolves #67985" in your commit message this PR would have been automatically closed on commit.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants