
[mlir][amx] Simplify intrinsic generation #140559

Merged 3 commits on May 23, 2025

Conversation

adam-smnk
Contributor

Replaces the separate named AMX intrinsic operations with direct calls to LLVM intrinsic functions.
The existing AMX tests are updated and expanded.

The separate conversion step that translated AMX intrinsics into LLVM IR is eliminated. That translation is now handled by the existing LLVM dialect infrastructure.

Related RFC: https://discourse.llvm.org/t/rfc-simplify-x86-intrinsic-generation/85581/7
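
For readers skimming the diff, the intrinsic name selection that the new `getIntrinsicName()` methods perform can be sketched in isolation. The snippet below is only an illustration, not code from the patch: `FloatKind`, `tileMulFIntrinsicName`, `tileMulIIntrinsicName`, and `checkIntrinsicNames` are hypothetical stand-ins, and only the string-building logic follows the `TileMulFOp` and `TileMulIOp` definitions in `AMX.td`.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for the element type of the LHS tile; the patch
// queries getLhsTileType().getElementType().isF16() instead.
enum class FloatKind { F16, BF16 };

// Follows the string building in TileMulFOp::getIntrinsicName():
// "llvm.x86.tdp" + ("fp16" | "bf16") + "ps.internal".
inline std::string tileMulFIntrinsicName(FloatKind lhsElement) {
  std::string intr = "llvm.x86.tdp";
  intr += lhsElement == FloatKind::F16 ? "fp16" : "bf16";
  intr += "ps.internal";
  return intr;
}

// Follows the string building in TileMulIOp::getIntrinsicName():
// "llvm.x86.tdpb" + ("u" | "s") for LHS + ("u" | "s") for RHS + "d.internal".
inline std::string tileMulIIntrinsicName(bool isZextLhs, bool isZextRhs) {
  std::string intr = "llvm.x86.tdpb";
  intr += isZextLhs ? "u" : "s";
  intr += isZextRhs ? "u" : "s";
  intr += "d.internal";
  return intr;
}

// Usage: every generated name matches one of the six arithmetic intrinsics
// that the removed AMX_IntrOp definitions used to wrap.
inline void checkIntrinsicNames() {
  assert(tileMulFIntrinsicName(FloatKind::BF16) == "llvm.x86.tdpbf16ps.internal");
  assert(tileMulFIntrinsicName(FloatKind::F16) == "llvm.x86.tdpfp16ps.internal");
  assert(tileMulIIntrinsicName(false, false) == "llvm.x86.tdpbssd.internal");
  assert(tileMulIIntrinsicName(false, true) == "llvm.x86.tdpbsud.internal");
  assert(tileMulIIntrinsicName(true, false) == "llvm.x86.tdpbusd.internal");
  assert(tileMulIIntrinsicName(true, true) == "llvm.x86.tdpbuud.internal");
}
```

The point of the sketch is that a single op now selects among several intrinsics at lowering time, based on its element type or `zext` flags, where the old design needed one dedicated `AMX_IntrOp` definition per intrinsic.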

@llvmbot
Member

llvmbot commented May 19, 2025

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir-amx

Author: Adam Siemieniuk (adam-smnk)

Patch is 51.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140559.diff

18 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMX/AMX.td (+71-86)
  • (modified) mlir/include/mlir/Dialect/AMX/AMXDialect.h (+4)
  • (added) mlir/include/mlir/Dialect/AMX/AMXInterfaces.td (+31)
  • (modified) mlir/include/mlir/Dialect/AMX/CMakeLists.txt (+2-3)
  • (modified) mlir/include/mlir/Dialect/AMX/Transforms.h (-3)
  • (modified) mlir/include/mlir/InitAllExtensions.h (-2)
  • (removed) mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h (-31)
  • (modified) mlir/include/mlir/Target/LLVMIR/Dialect/All.h (-2)
  • (modified) mlir/lib/Dialect/AMX/IR/AMXDialect.cpp (+188-2)
  • (modified) mlir/lib/Dialect/AMX/IR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt (-3)
  • (modified) mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp (+21-203)
  • (modified) mlir/lib/Target/LLVMIR/CMakeLists.txt (-1)
  • (removed) mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp (-56)
  • (removed) mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt (-16)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt (-1)
  • (modified) mlir/test/Dialect/AMX/legalize-for-llvm.mlir (+27-27)
  • (modified) mlir/test/Target/LLVMIR/amx.mlir (+87-10)
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 8a51df1ea183f..a484f2ca009a2 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -25,10 +25,11 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef AMX
-#define AMX
+#ifndef AMX_OPS
+#define AMX_OPS
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Dialect/AMX/AMXInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinTypes.td"
@@ -47,8 +48,6 @@ def AMX_Dialect : Dialect {
 
     This `AMX` dialect provides a bridge between MLIR concepts such as
     vectors and memrefs and the lower level LLVM IR support of AMX.
-    The dialect is split into user-facing AMX ops (AMX_Op) and
-    backend-facing intrinsic ops (AMX_IntrOp).
 
     Note that since configuration changes (implicit at dialect level) are
     costly, it is highly recommended to use the AMX dialect on same-shaped
@@ -135,21 +134,17 @@ def AMXTileI8 : AMXTileOf<[I8]>;
 class AMX_Op<string mnemonic, list<Trait> traits = []> :
   Op<AMX_Dialect, mnemonic, traits> {}
 
-// The "internal" intrinsics are meant for compiler usage.
-class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
-  LLVM_IntrOpBase<AMX_Dialect, mnemonic,
-                  "x86_" # !subst(".", "_", mnemonic) # "_internal",
-                  [], [], traits, numResults>;
-
 //===----------------------------------------------------------------------===//
-// AMX Op definitions (user facing).
+// AMX Op definitions
 //===----------------------------------------------------------------------===//
 
 //
 // Tile reset.
 //
 
-def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
+def TileZeroOp : AMX_Op<"tile_zero", [Pure,
+    AMXIntrinsicOpInterface
+  ]> {
   let summary = "tile zero operation";
   let description = [{
     Zeroes the destination tile, with the shape defined by the 2-dim
@@ -167,6 +162,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
     TileType getTileType() {
       return ::llvm::cast<TileType>(getRes().getType());
     }
+
+    std::string getIntrinsicName() {
+      return "llvm.x86.tilezero.internal";
+    }
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
   let assemblyFormat = "attr-dict `:` qualified(type($res))";
   let hasVerifier = 1;
@@ -176,7 +179,9 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
 // Tile memory operations.
 //
 
-def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
+def TileLoadOp : AMX_Op<"tile_load", [Pure,
+    AMXIntrinsicOpInterface
+  ]> {
   let summary = "tile load operation";
   let description = [{
     Loads a tile from memory defined by a base and indices, with the
@@ -200,13 +205,23 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
     TileType getTileType() {
       return ::llvm::cast<TileType>(getRes().getType());
     }
+
+    std::string getIntrinsicName() {
+      return "llvm.x86.tileloadd64.internal";
+    }
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
   let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
                        "type($base) `into` qualified(type($res))";
   let hasVerifier = 1;
 }
 
-def TileStoreOp : AMX_Op<"tile_store"> {
+def TileStoreOp : AMX_Op<"tile_store", [
+    AMXIntrinsicOpInterface
+  ]> {
   let summary = "tile store operation";
   let description = [{
     Stores a tile to memory defined by a base and indices, with the
@@ -230,6 +245,14 @@ def TileStoreOp : AMX_Op<"tile_store"> {
     TileType getTileType() {
       return ::llvm::cast<TileType>(getVal().getType());
     }
+
+    std::string getIntrinsicName() {
+      return "llvm.x86.tilestored64.internal";
+    }
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
   let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
                        "type($base) `,` qualified(type($val))";
@@ -240,8 +263,10 @@ def TileStoreOp : AMX_Op<"tile_store"> {
 // Tile arithmetic operations.
 //
 
-def TileMulFOp : AMX_Op<"tile_mulf", [
-    Pure, AllTypesMatch<["acc", "res"]>]> {
+def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
+    AMXIntrinsicOpInterface,
+    AllTypesMatch<["acc", "res"]>
+  ]> {
   let summary = "tile multiplication operation (floating-point)";
   let description = [{
     Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
@@ -270,6 +295,19 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
     TileType getTileType() {
       return ::llvm::cast<TileType>(getRes().getType());
     }
+
+    std::string getIntrinsicName() {
+      std::string intr = "llvm.x86.tdp";
+      auto elementType =
+        getLhsTileType().getElementType();
+      intr += elementType.isF16() ? "fp16" : "bf16";
+      intr += "ps.internal";
+      return intr;
+    }
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
   let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
                        "qualified(type($lhs)) `,` qualified(type($rhs))"
@@ -277,8 +315,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
   let hasVerifier = 1;
 }
 
-def TileMulIOp : AMX_Op<"tile_muli", [
-    Pure, AllTypesMatch<["acc", "res"]>]> {
+def TileMulIOp : AMX_Op<"tile_muli", [Pure,
+    AMXIntrinsicOpInterface,
+    AllTypesMatch<["acc", "res"]>
+  ]> {
   let summary = "tile multiplication operation (integer)";
   let description = [{
     Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
@@ -313,77 +353,22 @@ def TileMulIOp : AMX_Op<"tile_muli", [
     TileType getTileType() {
       return ::llvm::cast<TileType>(getRes().getType());
     }
+
+    std::string getIntrinsicName() {
+      std::string intr = "llvm.x86.tdpb";
+      intr += getIsZextLhs() ? "u" : "s";
+      intr += getIsZextRhs() ? "u" : "s";
+      intr += "d.internal";
+      return intr;
+    }
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
   let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
                        "qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
   let hasVerifier = 1;
 }
 
-//===----------------------------------------------------------------------===//
-// AMX IntrOp definitions (LLVM compiler facing).
-//===----------------------------------------------------------------------===//
-
-//
-// Tile reset. Parameters define the tile size.
-//
-
-def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>,
-  Arguments<(ins AnyInteger, AnyInteger)>;
-
-//
-// Tile memory operations. Parameters define the tile size,
-// base address, and stride between consecutive rows for the
-// memory operation.
-//
-
-def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger, LLVM_AnyPointer, AnyInteger)>;
-
-def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger, LLVM_AnyPointer, AnyInteger, LLVM_Type)>;
-
-//
-// Tile multiplication operations (series of dot products). Parameters
-// define the tile sizes and source and destination tiles for the
-// operation. Note that the prefix "tdp" stands for tile dot product.
-//
-
-// Dot product of bf16 tiles into f32 tile.
-def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger,
-		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-// Dot product of f16 tiles into f32 tile.
-def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger,
-		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-// Dot product of i8 tiles into i32 tile (with sign/sign extension).
-def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger,
-		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-// Dot product of i8 tiles into i32 tile (with sign/zero extension).
-def LLVM_x86_amx_tdpbsud : AMX_IntrOp<"tdpbsud", 1>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger,
-		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-// Dot product of i8 tiles into i32 tile (with zero/sign extension).
-def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger,
-		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-// Dot product of i8 tiles into i32 tile (with zero/zero extension).
-def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger,
-		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-#endif // AMX
+#endif // AMX_OPS
diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
index c0553ad8733fd..c79f31d4c994a 100644
--- a/mlir/include/mlir/Dialect/AMX/AMXDialect.h
+++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
@@ -14,11 +14,15 @@
 #define MLIR_DIALECT_AMX_AMXDIALECT_H_
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+/// Include the generated interface declarations.
+#include "mlir/Dialect/AMX/AMXInterfaces.h.inc"
+
 #include "mlir/Dialect/AMX/AMXDialect.h.inc"
 
 #define GET_TYPEDEF_CLASSES
diff --git a/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td b/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td
new file mode 100644
index 0000000000000..012d1ba7368f7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td
@@ -0,0 +1,31 @@
+//===- AMXInterfaces.td - AMX interfaces -------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines interfaces for the AMX dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef AMX_INTERFACES
+#define AMX_INTERFACES
+
+include "mlir/IR/Interfaces.td"
+include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// AMX Intrinsic Interface
+//===----------------------------------------------------------------------===//
+
+def AMXIntrinsicOpInterface
+    : OpInterface<"AMXIntrinsicOp", [OneToOneIntrinsicOpInterface]> {
+  let description = [{
+    A wrapper interface for operations representing AMX LLVM intrinsics.
+  }];
+  let cppNamespace = "::mlir::amx";
+}
+
+#endif // AMX_INTERFACES
diff --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
index f3f1aff5a6360..f875c78d240cc 100644
--- a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_dialect(AMX amx)
 add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)
 
-set(LLVM_TARGET_DEFINITIONS AMX.td)
-mlir_tablegen(AMXConversions.inc -gen-llvmir-conversions)
-add_public_tablegen_target(MLIRAMXConversionsIncGen)
+add_mlir_interface(AMXInterfaces)
+add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen)
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 7391ec2ff6b14..4a751d99ceeee 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -25,9 +25,6 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
 /// intrinsics.
 void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
 
-/// Register LLVM conversion interface for AMX dialect.
-void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
-
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AMX_TRANSFORMS_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 37e4904cb48ed..1e3f7c649a8bd 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -32,7 +32,6 @@
 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
-#include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
 #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
@@ -84,7 +83,6 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   registerConvertOpenMPToLLVMInterface(registry);
   registerConvertSCFToEmitCInterface(registry);
   ub::registerConvertUBToLLVMInterface(registry);
-  registerConvertAMXToLLVMInterface(registry);
   gpu::registerConvertGpuToLLVMInterface(registry);
   NVVM::registerConvertGpuToNVVMInterface(registry);
   vector::registerConvertVectorToLLVMInterface(registry);
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h
deleted file mode 100644
index 4525ec3212196..0000000000000
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h
+++ /dev/null
@@ -1,31 +0,0 @@
-//===- AMXToLLVMIRTranslation.h - AMX to LLVM IR ----------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This provides registration calls for AMX dialect to LLVM IR translation.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H
-#define MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H
-
-namespace mlir {
-
-class DialectRegistry;
-class MLIRContext;
-
-/// Register the AMX dialect and the translation from it to the LLVM IR
-/// in the given registry;
-void registerAMXDialectTranslation(DialectRegistry &registry);
-
-/// Register the AMX dialect and the translation from it in the registry
-/// associated with the given context.
-void registerAMXDialectTranslation(MLIRContext &context);
-
-} // namespace mlir
-
-#endif // MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
index e043ff2f6825c..60615cf601655 100644
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
@@ -14,7 +14,6 @@
 #ifndef MLIR_TARGET_LLVMIR_DIALECT_ALL_H
 #define MLIR_TARGET_LLVMIR_DIALECT_ALL_H
 
-#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h"
@@ -37,7 +36,6 @@ class DialectRegistry;
 /// corresponding translation interfaces.
 static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
   registerArmNeonDialectTranslation(registry);
-  registerAMXDialectTranslation(registry);
   registerArmSMEDialectTranslation(registry);
   registerArmSVEDialectTranslation(registry);
   registerBuiltinDialectTranslation(registry);
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 829f48e223383..69f524e1c311d 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -11,6 +11,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -21,6 +23,8 @@
 
 using namespace mlir;
 
+#include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc"
+
 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
 
 void amx::AMXDialect::initialize() {
@@ -60,24 +64,168 @@ static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
   return success();
 }
 
+/// Get pointer to a memref descriptor.
+/// Optionally, the base pointer can be offset using linearized index computed
+/// from the given indices.
+static Value getBufferPtr(Location loc, MemRefType type, Value buffer,
+                          ValueRange indices,
+                          const LLVMTypeConverter &typeConverter,
+                          RewriterBase &rewriter) {
+  auto [strides, offset] = type.getStridesAndOffset();
+
+  MemRefDescriptor memRefDescriptor(buffer);
+  Value base = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
+
+  int numIndices = indices.size();
+  if (numIndices == 0)
+    return base;
+
+  assert(type.getRank() == numIndices &&
+         "expects number of indices equal to memref rank");
+  Value index;
+  Type indexType = typeConverter.getIndexType();
+  for (int i = 0; i < numIndices; ++i) {
+    Value increment = indices[i];
+    if (strides[i] != 1) { // Skip if stride is 1.
+      Value stride =
+          ShapedType::isDynamic(strides[i])
+              ? memRefDescriptor.stride(rewriter, loc, i)
+              : rewriter.create<LLVM::ConstantOp>(
+                    loc, indexType, rewriter.getIndexAttr(strides[i]));
+      increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
+    }
+    index =
+        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+  }
+
+  Type elementPtrType = memRefDescriptor.getElementPtrType();
+  return rewriter.create<LLVM::GEPOp>(
+      loc, elementPtrType, typeConverter.convertType(type.getElementType()),
+      base, index);
+}
+
+/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
+/// dimension directly translates into the number of rows of the tiles.
+/// The second dimensions needs to be scaled by the number of bytes.
+static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
+                                       RewriterBase &rewriter) {
+  Type llvmInt16Type = rewriter.getIntegerType(16);
+  unsigned width = tType.getElementType().getIntOrFloatBitWidth();
+  assert(llvm::isPowerOf2_64(width) && width >= 8);
+  unsigned bytes = width >> 3;
+  auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
+  auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
+  return SmallVector<Value>{
+      rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
+      rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)};
+}
+
+/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
+/// shape may "envelop" the actual tile shape, and may be dynamically sized.
+/// Returns failure if proper stride couldn't be found.
+static Value getStride(Location loc, MemRefType mType, Value base,
+                       RewriterBase &rewriter) {
+  assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
+  int64_t preLast = mType.getRank() - 2;
+  Type llvmInt64Type = rewriter.getIntegerType(64);
+  unsigned width = mType.getElementType().getIntOrFloatBitWidth();
+  assert(llvm::isPowerOf2_64(width) && width >= 8);
+  unsigned bytes = width >> 3;
+  auto [strides, offset] = mType.getStridesAndOffset();
+  if (strides[preLast] == ShapedType::kDynamic) {
+    // Dynamic stride needs code to compute the ...
[truncated]
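
The arithmetic behind the `getTileSizes` and `getBufferPtr` helpers in the truncated `AMXDialect.cpp` hunk above can be approximated with plain integers. This is a rough sketch under stated assumptions: the real helpers emit `LLVM::ConstantOp`, `LLVM::MulOp`, `LLVM::AddOp`, and `LLVM::GEPOp` through a rewriter and also handle dynamic strides, while the hypothetical `tileSizes` and `linearizedIndex` functions below only reproduce the static-shape math.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Small helper standing in for llvm::isPowerOf2_64 (hypothetical name).
inline bool isPowerOfTwo(unsigned v) { return v != 0 && (v & (v - 1)) == 0; }

// Static-shape version of the getTileSizes math: the first tile dimension is
// the row count, the second is scaled by the element size in bytes.
inline std::pair<std::uint16_t, std::uint16_t>
tileSizes(std::int64_t rows, std::int64_t cols, unsigned elementBitWidth) {
  assert(isPowerOfTwo(elementBitWidth) && elementBitWidth >= 8);
  unsigned bytes = elementBitWidth >> 3;
  return {static_cast<std::uint16_t>(rows),
          static_cast<std::uint16_t>(cols * bytes)};
}

// Static-stride version of the index linearization in getBufferPtr:
// the element offset is the sum of indices[i] * strides[i].
inline std::int64_t linearizedIndex(const std::vector<std::int64_t> &indices,
                                    const std::vector<std::int64_t> &strides) {
  assert(indices.size() == strides.size() &&
         "expects number of indices equal to memref rank");
  std::int64_t index = 0;
  for (std::size_t i = 0; i < indices.size(); ++i)
    index += indices[i] * strides[i];
  return index;
}
```

For example, a 16x32 tile of 16-bit elements yields the size pair (16, 64), and indices `[2, 3]` into a buffer with strides `[64, 1]` give an element offset of 131.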

@llvmbot
Member

llvmbot commented May 19, 2025

@llvm/pr-subscribers-mlir

Author: Adam Siemieniuk (adam-smnk)

-
-// Dot product of i8 tiles into i32 tile (with zero/sign extension).
-def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger,
-		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-// Dot product of i8 tiles into i32 tile (with zero/zero extension).
-def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger,
-		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-#endif // AMX
+#endif // AMX_OPS
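
Reviewer-style note on the two `getIntrinsicName()` helpers above: the name-building logic can be exercised outside of MLIR. The following is a minimal standalone sketch, not code from the patch. The free functions `tileMulFIntrinsicName`, `tileMulIIntrinsicName`, and `checkKnownNames`, and their boolean parameters, are hypothetical stand-ins for the op accessors (`getLhsTileType().getElementType().isF16()`, `getIsZextLhs()`, `getIsZextRhs()`).

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for TileMulFOp::getIntrinsicName(): the patch selects
// "fp16" when the LHS element type is f16 and falls back to "bf16" otherwise.
std::string tileMulFIntrinsicName(bool lhsIsF16) {
  std::string intr = "llvm.x86.tdp";
  intr += lhsIsF16 ? "fp16" : "bf16";
  intr += "ps.internal";
  return intr;
}

// Hypothetical stand-in for TileMulIOp::getIntrinsicName(): one letter per
// operand, "u" for zero extension and "s" for sign extension, LHS first.
std::string tileMulIIntrinsicName(bool isZextLhs, bool isZextRhs) {
  std::string intr = "llvm.x86.tdpb";
  intr += isZextLhs ? "u" : "s";
  intr += isZextRhs ? "u" : "s";
  intr += "d.internal";
  return intr;
}

// The generated names line up with the six removed AMX_IntrOp definitions,
// with the ".internal" suffix appended.
void checkKnownNames() {
  assert(tileMulFIntrinsicName(false) == "llvm.x86.tdpbf16ps.internal");
  assert(tileMulFIntrinsicName(true) == "llvm.x86.tdpfp16ps.internal");
  assert(tileMulIIntrinsicName(false, false) == "llvm.x86.tdpbssd.internal");
  assert(tileMulIIntrinsicName(false, true) == "llvm.x86.tdpbsud.internal");
  assert(tileMulIIntrinsicName(true, false) == "llvm.x86.tdpbusd.internal");
  assert(tileMulIIntrinsicName(true, true) == "llvm.x86.tdpbuud.internal");
}
```

Calling `checkKnownNames()` from a test driver covers every combination the two ops can produce, which is one way to see that the string composition replaces the per-intrinsic TableGen definitions deleted in this hunk.
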
diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
index c0553ad8733fd..c79f31d4c994a 100644
--- a/mlir/include/mlir/Dialect/AMX/AMXDialect.h
+++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
@@ -14,11 +14,15 @@
 #define MLIR_DIALECT_AMX_AMXDIALECT_H_
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+/// Include the generated interface declarations.
+#include "mlir/Dialect/AMX/AMXInterfaces.h.inc"
+
 #include "mlir/Dialect/AMX/AMXDialect.h.inc"
 
 #define GET_TYPEDEF_CLASSES
diff --git a/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td b/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td
new file mode 100644
index 0000000000000..012d1ba7368f7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td
@@ -0,0 +1,31 @@
+//===- AMXInterfaces.td - AMX interfaces -------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines interfaces for the AMX dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef AMX_INTERFACES
+#define AMX_INTERFACES
+
+include "mlir/IR/Interfaces.td"
+include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// AMX Intrinsic Interface
+//===----------------------------------------------------------------------===//
+
+def AMXIntrinsicOpInterface
+    : OpInterface<"AMXIntrinsicOp", [OneToOneIntrinsicOpInterface]> {
+  let description = [{
+    A wrapper interface for operations representing AMX LLVM intrinsics.
+  }];
+  let cppNamespace = "::mlir::amx";
+}
+
+#endif // AMX_INTERFACES
diff --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
index f3f1aff5a6360..f875c78d240cc 100644
--- a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_dialect(AMX amx)
 add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)
 
-set(LLVM_TARGET_DEFINITIONS AMX.td)
-mlir_tablegen(AMXConversions.inc -gen-llvmir-conversions)
-add_public_tablegen_target(MLIRAMXConversionsIncGen)
+add_mlir_interface(AMXInterfaces)
+add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen)
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 7391ec2ff6b14..4a751d99ceeee 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -25,9 +25,6 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
 /// intrinsics.
 void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
 
-/// Register LLVM conversion interface for AMX dialect.
-void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
-
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AMX_TRANSFORMS_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 37e4904cb48ed..1e3f7c649a8bd 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -32,7 +32,6 @@
 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
-#include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
 #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
@@ -84,7 +83,6 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   registerConvertOpenMPToLLVMInterface(registry);
   registerConvertSCFToEmitCInterface(registry);
   ub::registerConvertUBToLLVMInterface(registry);
-  registerConvertAMXToLLVMInterface(registry);
   gpu::registerConvertGpuToLLVMInterface(registry);
   NVVM::registerConvertGpuToNVVMInterface(registry);
   vector::registerConvertVectorToLLVMInterface(registry);
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h
deleted file mode 100644
index 4525ec3212196..0000000000000
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h
+++ /dev/null
@@ -1,31 +0,0 @@
-//===- AMXToLLVMIRTranslation.h - AMX to LLVM IR ----------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This provides registration calls for AMX dialect to LLVM IR translation.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H
-#define MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H
-
-namespace mlir {
-
-class DialectRegistry;
-class MLIRContext;
-
-/// Register the AMX dialect and the translation from it to the LLVM IR
-/// in the given registry;
-void registerAMXDialectTranslation(DialectRegistry &registry);
-
-/// Register the AMX dialect and the translation from it in the registry
-/// associated with the given context.
-void registerAMXDialectTranslation(MLIRContext &context);
-
-} // namespace mlir
-
-#endif // MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
index e043ff2f6825c..60615cf601655 100644
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
@@ -14,7 +14,6 @@
 #ifndef MLIR_TARGET_LLVMIR_DIALECT_ALL_H
 #define MLIR_TARGET_LLVMIR_DIALECT_ALL_H
 
-#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h"
@@ -37,7 +36,6 @@ class DialectRegistry;
 /// corresponding translation interfaces.
 static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
   registerArmNeonDialectTranslation(registry);
-  registerAMXDialectTranslation(registry);
   registerArmSMEDialectTranslation(registry);
   registerArmSVEDialectTranslation(registry);
   registerBuiltinDialectTranslation(registry);
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 829f48e223383..69f524e1c311d 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -11,6 +11,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -21,6 +23,8 @@
 
 using namespace mlir;
 
+#include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc"
+
 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
 
 void amx::AMXDialect::initialize() {
@@ -60,24 +64,168 @@ static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
   return success();
 }
 
+/// Get pointer to a memref descriptor.
+/// Optionally, the base pointer can be offset using linearized index computed
+/// from the given indices.
+static Value getBufferPtr(Location loc, MemRefType type, Value buffer,
+                          ValueRange indices,
+                          const LLVMTypeConverter &typeConverter,
+                          RewriterBase &rewriter) {
+  auto [strides, offset] = type.getStridesAndOffset();
+
+  MemRefDescriptor memRefDescriptor(buffer);
+  Value base = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
+
+  int numIndices = indices.size();
+  if (numIndices == 0)
+    return base;
+
+  assert(type.getRank() == numIndices &&
+         "expects number of indices equal to memref rank");
+  Value index;
+  Type indexType = typeConverter.getIndexType();
+  for (int i = 0; i < numIndices; ++i) {
+    Value increment = indices[i];
+    if (strides[i] != 1) { // Skip if stride is 1.
+      Value stride =
+          ShapedType::isDynamic(strides[i])
+              ? memRefDescriptor.stride(rewriter, loc, i)
+              : rewriter.create<LLVM::ConstantOp>(
+                    loc, indexType, rewriter.getIndexAttr(strides[i]));
+      increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
+    }
+    index =
+        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+  }
+
+  Type elementPtrType = memRefDescriptor.getElementPtrType();
+  return rewriter.create<LLVM::GEPOp>(
+      loc, elementPtrType, typeConverter.convertType(type.getElementType()),
+      base, index);
+}
+
+/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
+/// dimension directly translates into the number of rows of the tiles.
+/// The second dimension needs to be scaled by the number of bytes.
+static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
+                                       RewriterBase &rewriter) {
+  Type llvmInt16Type = rewriter.getIntegerType(16);
+  unsigned width = tType.getElementType().getIntOrFloatBitWidth();
+  assert(llvm::isPowerOf2_64(width) && width >= 8);
+  unsigned bytes = width >> 3;
+  auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
+  auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
+  return SmallVector<Value>{
+      rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
+      rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)};
+}
+
+/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
+/// shape may "envelop" the actual tile shape, and may be dynamically sized.
+/// Returns failure if proper stride couldn't be found.
+static Value getStride(Location loc, MemRefType mType, Value base,
+                       RewriterBase &rewriter) {
+  assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
+  int64_t preLast = mType.getRank() - 2;
+  Type llvmInt64Type = rewriter.getIntegerType(64);
+  unsigned width = mType.getElementType().getIntOrFloatBitWidth();
+  assert(llvm::isPowerOf2_64(width) && width >= 8);
+  unsigned bytes = width >> 3;
+  auto [strides, offset] = mType.getStridesAndOffset();
+  if (strides[preLast] == ShapedType::kDynamic) {
+    // Dynamic stride needs code to compute the ...
[truncated]
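
The arithmetic in `getTileSizes` and `getBufferPtr` above can be checked in isolation. The sketch below is a plain C++ model, not MLIR code and not part of the patch: `tileSizes` and `linearizedIndex` are hypothetical names, and they operate on constants where the patch emits `LLVM::ConstantOp`, `LLVM::MulOp`, and `LLVM::AddOp`. It only models static strides; in the patch, dynamic strides are read from the memref descriptor instead. `getStride` is left out because the diff is truncated before its body.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical model of getTileSizes(): the row count is used directly and
// the column count is scaled to bytes by the element width.
std::pair<uint16_t, uint16_t> tileSizes(int64_t rows, int64_t cols,
                                        unsigned elementBitWidth) {
  // Same precondition as the patch: a power-of-two width of at least 8 bits.
  assert(elementBitWidth >= 8 &&
         (elementBitWidth & (elementBitWidth - 1)) == 0);
  unsigned bytes = elementBitWidth >> 3;
  return {static_cast<uint16_t>(rows), static_cast<uint16_t>(cols * bytes)};
}

// Hypothetical model of the index computation in getBufferPtr(): the element
// offset is the sum of index[i] * stride[i]. With no indices the offset is 0,
// which corresponds to returning the base pointer unchanged.
int64_t linearizedIndex(const std::vector<int64_t> &indices,
                        const std::vector<int64_t> &strides) {
  assert(indices.size() == strides.size() &&
         "expects one stride per index");
  int64_t index = 0;
  for (std::size_t i = 0; i < indices.size(); ++i)
    index += indices[i] * strides[i];
  return index;
}
```

Under this model a 16x32 tile of 16-bit elements, a 16x64 tile of 8-bit elements, and a 16x16 tile of 32-bit elements all map to the same size pair (16 rows, 64 bytes per row), and indices `[2, 3]` into a buffer with strides `[64, 1]` give an element offset of 131.
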

@adam-smnk
Contributor Author

@ienkovich FYI

Contributor

@gysit gysit left a comment


LGTM from my side.

I added some minor comments/questions.


github-actions bot commented May 21, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor

@banach-space banach-space left a comment


The generic MLIR changes look good to me. I'm not familiar enough with the AMX dialect to review that part in depth, but this is clearly aligned with the RFC and our previous discussions.

If there are no further comments, let's land this in the next day or two.

Thanks for all the clean-up, Adam!

adam-smnk added 3 commits May 23, 2025 12:35
Replaces separate amx named intrinsic operations with direct calls to
LLVM intrinsic functions.
The existing amx tests are updated and expanded.

The separate conversion step translating amx intrinsics into LLVM IR
is eliminated. Instead, this step is now performed by the existing
llvm dialect infrastructure.

Related RFC: https://discourse.llvm.org/t/rfc-simplify-x86-intrinsic-generation/85581
@adam-smnk adam-smnk merged commit 0fa3ba7 into llvm:main May 23, 2025
11 checks passed
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Jun 3, 2025
adam-smnk added a commit that referenced this pull request Jun 12, 2025
Restores mistakenly removed AMX interface which ensures that the custom
tile type is converted to its LLVM equivalent within other operations
such as control flow.

Fix after #140559
rolfmorel added a commit to libxsmm/tpp-mlir that referenced this pull request Jun 12, 2025
* llvm/llvm-project#139340
```
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp
```

* llvm/llvm-project#141466 &
llvm/llvm-project#141019
  * Add `BufferizationState &state` to `bufferize` and `getBuffer` 

* llvm/llvm-project#143159 &
llvm/llvm-project#142683 &
llvm/llvm-project#143779
  * Updates to `transform.apply_registered_pass` and its Python-bindings

* llvm/llvm-project#143217
  * `tilingResult->mergeResult.replacements` -> `tilingResult->replacements`

* llvm/llvm-project#140559 &
llvm/llvm-project#143871
  * Change CHECK lines which expected `amx.` ops to `llvm.call_intrinsic`s & fix which enables conversion again.
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
Fix after llvm#140559
4 participants