diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index 2bf9a021f48e1..7e946495e3e7f 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -69,6 +69,15 @@ SmallVector decomposeValue(OpBuilder &builder, Location loc, Value src, /// function is used to combine multiple values into a single value. Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType); + +/// Performs the index computation to get to the element at `indices` of the +/// memory pointed to by `memRefDesc`, using the layout map of `type`. +/// The indices are linearized as: +/// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. +Value getStridedElementPtr( + OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, + MemRefType type, Value memRefDesc, ValueRange indices, + LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none); } // namespace LLVM /// Base class for operation conversions targeting the LLVM IR dialect. It @@ -107,8 +116,8 @@ class ConvertToLLVMPattern : public ConversionPattern { static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value); - // This is a strided getElementPtr variant that linearizes subscripts as: - // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. + /// Convenience wrapper for the corresponding helper utility. + /// This is a strided getElementPtr variant with linearized subscripts. Value getStridedElementPtr( ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td index 8a51df1ea183f..6bbde43e2d011 100644 --- a/mlir/include/mlir/Dialect/AMX/AMX.td +++ b/mlir/include/mlir/Dialect/AMX/AMX.td @@ -29,6 +29,7 @@ #define AMX include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Dialect/AMX/AMXInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinTypes.td" @@ -47,8 +48,6 @@ def AMX_Dialect : Dialect { This `AMX` dialect provides a bridge between MLIR concepts such as vectors and memrefs and the lower level LLVM IR support of AMX. - The dialect is split into user-facing AMX ops (AMX_Op) and - backend-facing intrinsic ops (AMX_IntrOp). Note that since configuration changes (implicit at dialect level) are costly, it is highly recommended to use the AMX dialect on same-shaped @@ -135,21 +134,17 @@ def AMXTileI8 : AMXTileOf<[I8]>; class AMX_Op traits = []> : Op {} -// The "internal" intrinsics are meant for compiler usage. -class AMX_IntrOp traits = []> : - LLVM_IntrOpBase; - //===----------------------------------------------------------------------===// -// AMX Op definitions (user facing). +// AMX Op definitions //===----------------------------------------------------------------------===// // // Tile reset. 
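//
// A minimal, hypothetical sketch (not part of this patch) of how a downstream
// op implementing AMXIntrinsicOpInterface (or, more generally,
// OneToOneIntrinsicOpInterface) can combine the getIntrinsicOperands hook with
// the newly exposed free function LLVM::getStridedElementPtr declared in
// Pattern.h above. The op `my::ExampleLoadOp` and its accessors
// getBase()/getIndices()/getMemRefType() are illustrative assumptions; only
// the helper's signature and the Adaptor usage mirror this patch.
//
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"

using namespace mlir;

SmallVector<Value>
my::ExampleLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
                                        const LLVMTypeConverter &typeConverter,
                                        RewriterBase &rewriter) {
  // The generated adaptor exposes the already-converted operands by name.
  Adaptor adaptor(operands, *this);
  // Linearize the subscripts into a single element pointer:
  // base_offset + index_0 * stride_0 + ... + index_n * stride_n.
  Value ptr = LLVM::getStridedElementPtr(rewriter, getLoc(), typeConverter,
                                         getMemRefType(), adaptor.getBase(),
                                         adaptor.getIndices());
  return {ptr};
}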
// -def TileZeroOp : AMX_Op<"tile_zero", [Pure]> { +def TileZeroOp : AMX_Op<"tile_zero", [Pure, + AMXIntrinsicOpInterface + ]> { let summary = "tile zero operation"; let description = [{ Zeroes the destination tile, with the shape defined by the 2-dim @@ -167,6 +162,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> { TileType getTileType() { return ::llvm::cast(getRes().getType()); } + + std::string getIntrinsicName() { + return "llvm.x86.tilezero.internal"; + } + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); }]; let assemblyFormat = "attr-dict `:` qualified(type($res))"; let hasVerifier = 1; @@ -176,7 +179,9 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> { // Tile memory operations. // -def TileLoadOp : AMX_Op<"tile_load", [Pure]> { +def TileLoadOp : AMX_Op<"tile_load", [Pure, + AMXIntrinsicOpInterface + ]> { let summary = "tile load operation"; let description = [{ Loads a tile from memory defined by a base and indices, with the @@ -200,13 +205,23 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> { TileType getTileType() { return ::llvm::cast(getRes().getType()); } + + std::string getIntrinsicName() { + return "llvm.x86.tileloadd64.internal"; + } + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); }]; let assemblyFormat = "$base `[` $indices `]` attr-dict `:` " "type($base) `into` qualified(type($res))"; let hasVerifier = 1; } -def TileStoreOp : AMX_Op<"tile_store"> { +def TileStoreOp : AMX_Op<"tile_store", [ + AMXIntrinsicOpInterface + ]> { let summary = "tile store operation"; let description = [{ Stores a tile to memory defined by a base and indices, with the @@ -230,6 +245,14 @@ def TileStoreOp : AMX_Op<"tile_store"> { TileType getTileType() { return ::llvm::cast(getVal().getType()); } + + std::string getIntrinsicName() { + return "llvm.x86.tilestored64.internal"; + } + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); }]; let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` " "type($base) `,` qualified(type($val))"; @@ -240,8 +263,10 @@ def TileStoreOp : AMX_Op<"tile_store"> { // Tile arithmetic operations. // -def TileMulFOp : AMX_Op<"tile_mulf", [ - Pure, AllTypesMatch<["acc", "res"]>]> { +def TileMulFOp : AMX_Op<"tile_mulf", [Pure, + AMXIntrinsicOpInterface, + AllTypesMatch<["acc", "res"]> + ]> { let summary = "tile multiplication operation (floating-point)"; let description = [{ Multiplies a "m x k" tile with a "k x n" tile and accumulates the results @@ -270,6 +295,19 @@ def TileMulFOp : AMX_Op<"tile_mulf", [ TileType getTileType() { return ::llvm::cast(getRes().getType()); } + + std::string getIntrinsicName() { + std::string intr = "llvm.x86.tdp"; + auto elementType = + getLhsTileType().getElementType(); + intr += elementType.isF16() ? 
"fp16" : "bf16"; + intr += "ps.internal"; + return intr; + } + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); }]; let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` " "qualified(type($lhs)) `,` qualified(type($rhs))" @@ -277,8 +315,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [ let hasVerifier = 1; } -def TileMulIOp : AMX_Op<"tile_muli", [ - Pure, AllTypesMatch<["acc", "res"]>]> { +def TileMulIOp : AMX_Op<"tile_muli", [Pure, + AMXIntrinsicOpInterface, + AllTypesMatch<["acc", "res"]> + ]> { let summary = "tile multiplication operation (integer)"; let description = [{ Multiplies a "m x k" tile with a "k x n" tile and accumulates the results @@ -313,77 +353,22 @@ def TileMulIOp : AMX_Op<"tile_muli", [ TileType getTileType() { return ::llvm::cast(getRes().getType()); } + + std::string getIntrinsicName() { + std::string intr = "llvm.x86.tdpb"; + intr += getIsZextLhs() ? "u" : "s"; + intr += getIsZextRhs() ? "u" : "s"; + intr += "d.internal"; + return intr; + } + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); }]; let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` " "qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) "; let hasVerifier = 1; } -//===----------------------------------------------------------------------===// -// AMX IntrOp definitions (LLVM compiler facing). -//===----------------------------------------------------------------------===// - -// -// Tile reset. Parameters define the tile size. -// - -def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>, - Arguments<(ins AnyInteger, AnyInteger)>; - -// -// Tile memory operations. Parameters define the tile size, -// base address, and stride between consecutive rows for the -// memory operation. -// - -def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>, - Arguments<(ins AnyInteger, - AnyInteger, LLVM_AnyPointer, AnyInteger)>; - -def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>, - Arguments<(ins AnyInteger, - AnyInteger, LLVM_AnyPointer, AnyInteger, LLVM_Type)>; - -// -// Tile multiplication operations (series of dot products). Parameters -// define the tile sizes and source and destination tiles for the -// operation. Note that the prefix "tdp" stands for tile dot product. -// - -// Dot product of bf16 tiles into f32 tile. -def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>, - Arguments<(ins AnyInteger, - AnyInteger, - AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; - -// Dot product of f16 tiles into f32 tile. -def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>, - Arguments<(ins AnyInteger, - AnyInteger, - AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; - -// Dot product of i8 tiles into i32 tile (with sign/sign extension). -def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>, - Arguments<(ins AnyInteger, - AnyInteger, - AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; - -// Dot product of i8 tiles into i32 tile (with sign/zero extension). -def LLVM_x86_amx_tdpbsud : AMX_IntrOp<"tdpbsud", 1>, - Arguments<(ins AnyInteger, - AnyInteger, - AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; - -// Dot product of i8 tiles into i32 tile (with zero/sign extension). 
-def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>, - Arguments<(ins AnyInteger, - AnyInteger, - AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; - -// Dot product of i8 tiles into i32 tile (with zero/zero extension). -def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>, - Arguments<(ins AnyInteger, - AnyInteger, - AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; - #endif // AMX diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h index c0553ad8733fd..c79f31d4c994a 100644 --- a/mlir/include/mlir/Dialect/AMX/AMXDialect.h +++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h @@ -14,11 +14,15 @@ #define MLIR_DIALECT_AMX_AMXDIALECT_H_ #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +/// Include the generated interface declarations. +#include "mlir/Dialect/AMX/AMXInterfaces.h.inc" + #include "mlir/Dialect/AMX/AMXDialect.h.inc" #define GET_TYPEDEF_CLASSES diff --git a/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td b/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td new file mode 100644 index 0000000000000..012d1ba7368f7 --- /dev/null +++ b/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td @@ -0,0 +1,31 @@ +//===- AMXInterfaces.td - AMX interfaces -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines interfaces for the AMX dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef AMX_INTERFACES +#define AMX_INTERFACES + +include "mlir/IR/Interfaces.td" +include "mlir/Dialect/LLVMIR/LLVMInterfaces.td" + +//===----------------------------------------------------------------------===// +// AMX Intrinsic Interface +//===----------------------------------------------------------------------===// + +def AMXIntrinsicOpInterface + : OpInterface<"AMXIntrinsicOp", [OneToOneIntrinsicOpInterface]> { + let description = [{ + A wrapper interface for operations representing AMX LLVM intrinsics. + }]; + let cppNamespace = "::mlir::amx"; +} + +#endif // AMX_INTERFACES diff --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt index f3f1aff5a6360..f875c78d240cc 100644 --- a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt @@ -1,6 +1,5 @@ add_mlir_dialect(AMX amx) add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx) -set(LLVM_TARGET_DEFINITIONS AMX.td) -mlir_tablegen(AMXConversions.inc -gen-llvmir-conversions) -add_public_tablegen_target(MLIRAMXConversionsIncGen) +add_mlir_interface(AMXInterfaces) +add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen) diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h index 7391ec2ff6b14..4a751d99ceeee 100644 --- a/mlir/include/mlir/Dialect/AMX/Transforms.h +++ b/mlir/include/mlir/Dialect/AMX/Transforms.h @@ -25,9 +25,6 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, /// intrinsics. void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target); -/// Register LLVM conversion interface for AMX dialect. 
-void registerConvertAMXToLLVMInterface(DialectRegistry ®istry); - } // namespace mlir #endif // MLIR_DIALECT_AMX_TRANSFORMS_H diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index 37e4904cb48ed..1e3f7c649a8bd 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -32,7 +32,6 @@ #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h" @@ -84,7 +83,6 @@ inline void registerAllExtensions(DialectRegistry ®istry) { registerConvertOpenMPToLLVMInterface(registry); registerConvertSCFToEmitCInterface(registry); ub::registerConvertUBToLLVMInterface(registry); - registerConvertAMXToLLVMInterface(registry); gpu::registerConvertGpuToLLVMInterface(registry); NVVM::registerConvertGpuToNVVMInterface(registry); vector::registerConvertVectorToLLVMInterface(registry); diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h deleted file mode 100644 index 4525ec3212196..0000000000000 --- a/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h +++ /dev/null @@ -1,31 +0,0 @@ -//===- AMXToLLVMIRTranslation.h - AMX to LLVM IR ----------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This provides registration calls for AMX dialect to LLVM IR translation. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H -#define MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H - -namespace mlir { - -class DialectRegistry; -class MLIRContext; - -/// Register the AMX dialect and the translation from it to the LLVM IR -/// in the given registry; -void registerAMXDialectTranslation(DialectRegistry ®istry); - -/// Register the AMX dialect and the translation from it in the registry -/// associated with the given context. -void registerAMXDialectTranslation(MLIRContext &context); - -} // namespace mlir - -#endif // MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h index e043ff2f6825c..60615cf601655 100644 --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -14,7 +14,6 @@ #ifndef MLIR_TARGET_LLVMIR_DIALECT_ALL_H #define MLIR_TARGET_LLVMIR_DIALECT_ALL_H -#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h" @@ -37,7 +36,6 @@ class DialectRegistry; /// corresponding translation interfaces. 
static inline void registerAllToLLVMIRTranslations(DialectRegistry ®istry) { registerArmNeonDialectTranslation(registry); - registerAMXDialectTranslation(registry); registerArmSMEDialectTranslation(registry); registerArmSVEDialectTranslation(registry); registerBuiltinDialectTranslation(registry); diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 48fbcbcdbbde9..86fb9166b7223 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -62,49 +62,8 @@ Value ConvertToLLVMPattern::getStridedElementPtr( ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags) const { - - auto [strides, offset] = type.getStridesAndOffset(); - - MemRefDescriptor memRefDescriptor(memRefDesc); - // Use a canonical representation of the start address so that later - // optimizations have a longer sequence of instructions to CSE. - // If we don't do that we would sprinkle the memref.offset in various - // position of the different address computations. - Value base = - memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type); - - LLVM::IntegerOverflowFlags intOverflowFlags = - LLVM::IntegerOverflowFlags::none; - if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) { - intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw; - } - if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) { - intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw; - } - - Type indexType = getIndexType(); - Value index; - for (int i = 0, e = indices.size(); i < e; ++i) { - Value increment = indices[i]; - if (strides[i] != 1) { // Skip if stride is 1. - Value stride = - ShapedType::isDynamic(strides[i]) - ? memRefDescriptor.stride(rewriter, loc, i) - : createIndexAttrConstant(rewriter, loc, indexType, strides[i]); - increment = rewriter.create(loc, increment, stride, - intOverflowFlags); - } - index = index ? rewriter.create(loc, index, increment, - intOverflowFlags) - : increment; - } - - Type elementPtrType = memRefDescriptor.getElementPtrType(); - return index ? rewriter.create( - loc, elementPtrType, - getTypeConverter()->convertType(type.getElementType()), - base, index, noWrapFlags) - : base; + return LLVM::getStridedElementPtr(rewriter, loc, *getTypeConverter(), type, + memRefDesc, indices, noWrapFlags); } // Check if the MemRefType `type` is supported by the lowering. We currently @@ -524,3 +483,52 @@ Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src, return res; } + +Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc, + const LLVMTypeConverter &converter, + MemRefType type, Value memRefDesc, + ValueRange indices, + LLVM::GEPNoWrapFlags noWrapFlags) { + auto [strides, offset] = type.getStridesAndOffset(); + + MemRefDescriptor memRefDescriptor(memRefDesc); + // Use a canonical representation of the start address so that later + // optimizations have a longer sequence of instructions to CSE. + // If we don't do that we would sprinkle the memref.offset in various + // position of the different address computations. 
+ Value base = memRefDescriptor.bufferPtr(builder, loc, converter, type); + + LLVM::IntegerOverflowFlags intOverflowFlags = + LLVM::IntegerOverflowFlags::none; + if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) { + intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw; + } + if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) { + intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw; + } + + Type indexType = converter.getIndexType(); + Value index; + for (int i = 0, e = indices.size(); i < e; ++i) { + Value increment = indices[i]; + if (strides[i] != 1) { // Skip if stride is 1. + Value stride = + ShapedType::isDynamic(strides[i]) + ? memRefDescriptor.stride(builder, loc, i) + : builder.create( + loc, indexType, builder.getIndexAttr(strides[i])); + increment = + builder.create(loc, increment, stride, intOverflowFlags); + } + index = index ? builder.create(loc, index, increment, + intOverflowFlags) + : increment; + } + + Type elementPtrType = memRefDescriptor.getElementPtrType(); + return index ? builder.create( + loc, elementPtrType, + converter.convertType(type.getElementType()), base, index, + noWrapFlags) + : base; +} diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp index 829f48e223383..12b375b373fa9 100644 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -11,6 +11,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/AMX/AMXDialect.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" @@ -21,6 +23,8 @@ using namespace mlir; +#include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc" + #include "mlir/Dialect/AMX/AMXDialect.cpp.inc" void amx::AMXDialect::initialize() { @@ -60,24 +64,127 @@ static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, return success(); } +/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first +/// dimension directly translates into the number of rows of the tiles. +/// The second dimensions needs to be scaled by the number of bytes. +static SmallVector getTileSizes(Location loc, amx::TileType tType, + RewriterBase &rewriter) { + Type llvmInt16Type = rewriter.getIntegerType(16); + unsigned width = tType.getElementType().getIntOrFloatBitWidth(); + assert(llvm::isPowerOf2_64(width) && width >= 8); + unsigned bytes = width >> 3; + auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0)); + auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes); + return SmallVector{ + rewriter.create(loc, llvmInt16Type, mattr), + rewriter.create(loc, llvmInt16Type, nattr)}; +} + +/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer +/// shape may "envelop" the actual tile shape, and may be dynamically sized. 
+static Value getStride(Location loc, MemRefType mType, Value base, + RewriterBase &rewriter) { + assert(mType.getRank() >= 2 && "Invalid shape for AMX strides"); + int64_t preLast = mType.getRank() - 2; + Type llvmInt64Type = rewriter.getIntegerType(64); + unsigned width = mType.getElementType().getIntOrFloatBitWidth(); + assert(llvm::isPowerOf2_64(width) && width >= 8); + unsigned bytes = width >> 3; + auto [strides, offset] = mType.getStridesAndOffset(); + if (strides[preLast] == ShapedType::kDynamic) { + // Dynamic stride needs code to compute the stride at runtime. + MemRefDescriptor memrefDescriptor(base); + auto attr = rewriter.getI64IntegerAttr(bytes); + Value scale = rewriter.create(loc, llvmInt64Type, attr); + return rewriter + .create(loc, llvmInt64Type, scale, + memrefDescriptor.stride(rewriter, loc, preLast)) + .getResult(); + } + // Use direct constant for static stride. + auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes); + return rewriter.create(loc, llvmInt64Type, attr) + .getResult(); +} + LogicalResult amx::TileZeroOp::verify() { return verifyTileSize(*this, getTileType()); } +SmallVector +amx::TileZeroOp::getIntrinsicOperands(ArrayRef operands, + const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + return getTileSizes(getLoc(), getTileType(), rewriter); +} + LogicalResult amx::TileLoadOp::verify() { - unsigned rank = getMemRefType().getRank(); + MemRefType memrefTy = getMemRefType(); + unsigned rank = memrefTy.getRank(); + if (rank < 2) + return emitOpError("requires at least 2D memref"); if (getIndices().size() != rank) return emitOpError("requires ") << rank << " indices"; + SmallVector strides; + int64_t offset; + if (failed(memrefTy.getStridesAndOffset(strides, offset)) || + strides.back() != 1) + return emitOpError("requires memref with unit innermost stride"); return verifyTileSize(*this, getTileType()); } +SmallVector +amx::TileLoadOp::getIntrinsicOperands(ArrayRef operands, + const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + auto loc = getLoc(); + Adaptor adaptor(operands, *this); + + SmallVector intrinsicOperands; + intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter)); + intrinsicOperands.push_back( + LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), + adaptor.getBase(), adaptor.getIndices())); + intrinsicOperands.push_back( + getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); + + return intrinsicOperands; +} + LogicalResult amx::TileStoreOp::verify() { - unsigned rank = getMemRefType().getRank(); + MemRefType memrefTy = getMemRefType(); + unsigned rank = memrefTy.getRank(); + if (rank < 2) + return emitOpError("requires at least 2D memref"); if (getIndices().size() != rank) return emitOpError("requires ") << rank << " indices"; + SmallVector strides; + int64_t offset; + if (failed(memrefTy.getStridesAndOffset(strides, offset)) || + strides.back() != 1) + return emitOpError("requires memref with unit innermost stride"); return verifyTileSize(*this, getTileType()); } +SmallVector +amx::TileStoreOp::getIntrinsicOperands(ArrayRef operands, + const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + auto loc = getLoc(); + Adaptor adaptor(operands, *this); + + SmallVector intrinsicOperands; + intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter)); + intrinsicOperands.push_back( + LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), + adaptor.getBase(), adaptor.getIndices())); + intrinsicOperands.push_back( + 
getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); + intrinsicOperands.push_back(adaptor.getVal()); + + return intrinsicOperands; +} + LogicalResult amx::TileMulFOp::verify() { amx::TileType aType = getLhsTileType(); amx::TileType bType = getRhsTileType(); @@ -95,6 +202,25 @@ LogicalResult amx::TileMulFOp::verify() { return success(); } +SmallVector +amx::TileMulFOp::getIntrinsicOperands(ArrayRef operands, + const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + auto loc = getLoc(); + Adaptor adaptor(operands, *this); + + amx::TileType aType = getLhsTileType(); + amx::TileType bType = getRhsTileType(); + SmallVector tsza = getTileSizes(loc, aType, rewriter); + SmallVector tszb = getTileSizes(loc, bType, rewriter); + + SmallVector intrinsicOperands = {tsza[0], tszb[1], + tsza[1], adaptor.getAcc(), + adaptor.getLhs(), adaptor.getRhs()}; + + return intrinsicOperands; +} + LogicalResult amx::TileMulIOp::verify() { amx::TileType aType = getLhsTileType(); amx::TileType bType = getRhsTileType(); @@ -112,6 +238,25 @@ LogicalResult amx::TileMulIOp::verify() { return success(); } +SmallVector +amx::TileMulIOp::getIntrinsicOperands(ArrayRef operands, + const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + auto loc = getLoc(); + Adaptor adaptor(operands, *this); + + amx::TileType aType = getLhsTileType(); + amx::TileType bType = getRhsTileType(); + SmallVector tsza = getTileSizes(loc, aType, rewriter); + SmallVector tszb = getTileSizes(loc, bType, rewriter); + + SmallVector intrinsicOperands = {tsza[0], tszb[1], + tsza[1], adaptor.getAcc(), + adaptor.getLhs(), adaptor.getRhs()}; + + return intrinsicOperands; +} + Type amx::TileType::parse(AsmParser &parser) { if (parser.parseLess()) return nullptr; diff --git a/mlir/lib/Dialect/AMX/IR/CMakeLists.txt b/mlir/lib/Dialect/AMX/IR/CMakeLists.txt index d109547b2438b..b6e2759843d5e 100644 --- a/mlir/lib/Dialect/AMX/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/AMX/IR/CMakeLists.txt @@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRAMXDialect LINK_LIBS PUBLIC MLIRIR + MLIRLLVMCommonConversion MLIRLLVMDialect MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt index 29340d4f45dd1..e827bc475e930 100644 --- a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt @@ -1,9 +1,6 @@ add_mlir_dialect_library(MLIRAMXTransforms LegalizeForLLVMExport.cpp - DEPENDS - MLIRAMXConversionsIncGen - LINK_LIBS PUBLIC MLIRAMXDialect MLIRIR diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp index 2168409184549..7471dc797e0fc 100644 --- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp @@ -21,224 +21,42 @@ using namespace mlir::amx; namespace { -/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first -/// dimension directly translates into the number of rows of the tiles. -/// The second dimensions needs to be scaled by the number of bytes. 
-std::pair getTileSizes(ConversionPatternRewriter &rewriter, - const LLVMTypeConverter &typeConverter, - amx::TileType tType, Location loc) { - Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16); - unsigned width = tType.getElementType().getIntOrFloatBitWidth(); - assert(llvm::isPowerOf2_64(width) && width >= 8); - unsigned bytes = width >> 3; - auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0)); - auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes); - return std::make_pair( - rewriter.create(loc, llvmInt16Type, mattr), - rewriter.create(loc, llvmInt16Type, nattr)); -} - -/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer -/// shape may "envelop" the actual tile shape, and may be dynamically sized. -/// Returns failure if proper stride couldn't be found. -FailureOr getStride(ConversionPatternRewriter &rewriter, - const LLVMTypeConverter &typeConverter, - MemRefType mType, Value base, Location loc) { - if (mType.getRank() < 2) - return failure(); - int64_t preLast = mType.getRank() - 2; - Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64); - unsigned width = mType.getElementType().getIntOrFloatBitWidth(); - assert(llvm::isPowerOf2_64(width) && width >= 8); - unsigned bytes = width >> 3; - int64_t offset; - SmallVector strides; - if (failed(mType.getStridesAndOffset(strides, offset)) || strides.back() != 1) - return failure(); - if (strides[preLast] == ShapedType::kDynamic) { - // Dynamic stride needs code to compute the stride at runtime. - MemRefDescriptor memrefDescriptor(base); - auto attr = rewriter.getI64IntegerAttr(bytes); - Value scale = rewriter.create(loc, llvmInt64Type, attr); - return rewriter - .create(loc, llvmInt64Type, scale, - memrefDescriptor.stride(rewriter, loc, preLast)) - .getResult(); - } - // Use direct constant for static stride. - auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes); - return rewriter.create(loc, llvmInt64Type, attr) - .getResult(); -} - -struct TileZeroConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - LogicalResult - matchAndRewrite(TileZeroOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - amx::TileType tType = op.getTileType(); - // Determine m x n tile sizes. - std::pair tsz = - getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc()); - // Replace operation with intrinsic. - Type resType = typeConverter->convertType(tType); - rewriter.replaceOpWithNewOp(op, resType, tsz.first, - tsz.second); - return success(); - } -}; - -struct TileLoadConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(TileLoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MemRefType mType = op.getMemRefType(); - amx::TileType tType = op.getTileType(); - // Determine m x n tile sizes. - std::pair tsz = - getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc()); - // Determine stride. - auto stride = getStride(rewriter, *getTypeConverter(), mType, - adaptor.getBase(), op.getLoc()); - if (failed(stride)) - return failure(); - // Replace operation with intrinsic. 
- Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType, - adaptor.getBase(), adaptor.getIndices()); - Type resType = typeConverter->convertType(tType); - rewriter.replaceOpWithNewOp( - op, resType, tsz.first, tsz.second, ptr, stride.value()); - return success(); - } -}; - -struct TileStoreConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(TileStoreOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MemRefType mType = op.getMemRefType(); - amx::TileType tType = op.getTileType(); - // Determine m x n tile sizes. - std::pair tsz = - getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc()); - // Determine stride. - auto stride = getStride(rewriter, *getTypeConverter(), mType, - adaptor.getBase(), op.getLoc()); - if (failed(stride)) - return failure(); - // Replace operation with intrinsic. - Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType, - adaptor.getBase(), adaptor.getIndices()); - rewriter.replaceOpWithNewOp( - op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal()); - return success(); - } -}; +/// Generic one-to-one conversion of simply mappable operations into calls +/// to their respective LLVM intrinsics. +struct AMXIntrinsicOpConversion + : public OpInterfaceConversionPattern { + using OpInterfaceConversionPattern< + amx::AMXIntrinsicOp>::OpInterfaceConversionPattern; + + AMXIntrinsicOpConversion(const LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(), + benefit), + typeConverter(typeConverter) {} -struct TileMulFConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(TileMulFOp op, OpAdaptor adaptor, + matchAndRewrite(amx::AMXIntrinsicOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - amx::TileType aType = op.getLhsTileType(); - amx::TileType bType = op.getRhsTileType(); - amx::TileType cType = op.getTileType(); - // Determine m x n x k tile sizes. - std::pair tsza = - getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc()); - std::pair tszb = - getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc()); - // Replace operation with intrinsic. - Type resType = typeConverter->convertType(cType); - if (aType.getElementType().isBF16()) - rewriter.replaceOpWithNewOp( - op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), - adaptor.getLhs(), adaptor.getRhs()); - else if (aType.getElementType().isF16()) - rewriter.replaceOpWithNewOp( - op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), - adaptor.getLhs(), adaptor.getRhs()); - else - llvm_unreachable("Unexpected element type for amx.mulf"); - return success(); + return LLVM::detail::intrinsicRewrite( + op, rewriter.getStringAttr(op.getIntrinsicName()), + op.getIntrinsicOperands(operands, typeConverter, rewriter), + typeConverter, rewriter); } -}; -struct TileMulIConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - LogicalResult - matchAndRewrite(TileMulIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - amx::TileType aType = op.getLhsTileType(); - amx::TileType bType = op.getRhsTileType(); - amx::TileType cType = op.getTileType(); - // Determine m x n x k tile sizes. 
- std::pair tsza = - getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc()); - std::pair tszb = - getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc()); - // Replace operation with intrinsic. - Type resType = typeConverter->convertType(cType); - bool zexta = op.getIsZextLhs(); - bool zextb = op.getIsZextRhs(); - if (zexta && zextb) - rewriter.replaceOpWithNewOp( - op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), - adaptor.getLhs(), adaptor.getRhs()); - else if (zexta && !zextb) - rewriter.replaceOpWithNewOp( - op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), - adaptor.getLhs(), adaptor.getRhs()); - else if (!zexta && zextb) - rewriter.replaceOpWithNewOp( - op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), - adaptor.getLhs(), adaptor.getRhs()); - else - rewriter.replaceOpWithNewOp( - op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), - adaptor.getLhs(), adaptor.getRhs()); - return success(); - } +private: + const LLVMTypeConverter &typeConverter; }; } // namespace void mlir::populateAMXLegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add(converter); + patterns.add(converter); converter.addConversion([&](amx::TileType type) { return LLVM::LLVMX86AMXType::get(&converter.getContext()); }); } void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) { - target.addLegalOp(); - target.addIllegalOp(); -} - -namespace { -/// Implement the interface to convert AMX to LLVM. -struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface { - using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; - - void populateConvertToLLVMConversionPatterns( - ConversionTarget &target, LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns) const final { - populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns); - } -}; -} // namespace - -void mlir::registerConvertAMXToLLVMInterface(DialectRegistry ®istry) { - registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) { - dialect->addInterfaces(); - }); + target.addIllegalDialect(); } diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt index 4ace3964e8ae0..af22a7ff04bf0 100644 --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -51,7 +51,6 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration MLIRArmNeonToLLVMIRTranslation MLIRArmSMEToLLVMIRTranslation MLIRArmSVEToLLVMIRTranslation - MLIRAMXToLLVMIRTranslation MLIRBuiltinToLLVMIRTranslation MLIRGPUToLLVMIRTranslation MLIRLLVMToLLVMIRTranslation diff --git a/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp deleted file mode 100644 index 044462d33cfd1..0000000000000 --- a/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp +++ /dev/null @@ -1,56 +0,0 @@ -//===- AMXToLLVMIRTranslation.cpp - Translate AMX to LLVM IR --------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a translation between the AMX dialect and LLVM IR. 
-// -//===----------------------------------------------------------------------===// - -#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h" -#include "mlir/Dialect/AMX/AMXDialect.h" -#include "mlir/IR/Operation.h" -#include "mlir/Target/LLVMIR/ModuleTranslation.h" - -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/IntrinsicsX86.h" - -using namespace mlir; -using namespace mlir::LLVM; - -namespace { -/// Implementation of the dialect interface that converts operations belonging -/// to the AMX dialect to LLVM IR. -class AMXDialectLLVMIRTranslationInterface - : public LLVMTranslationDialectInterface { -public: - using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; - - /// Translates the given operation to LLVM IR using the provided IR builder - /// and saving the state in `moduleTranslation`. - LogicalResult - convertOperation(Operation *op, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) const final { - Operation &opInst = *op; -#include "mlir/Dialect/AMX/AMXConversions.inc" - - return failure(); - } -}; -} // namespace - -void mlir::registerAMXDialectTranslation(DialectRegistry ®istry) { - registry.insert(); - registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) { - dialect->addInterfaces(); - }); -} - -void mlir::registerAMXDialectTranslation(MLIRContext &context) { - DialectRegistry registry; - registerAMXDialectTranslation(registry); - context.appendDialectRegistry(registry); -} diff --git a/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt deleted file mode 100644 index 733b4c2e31b80..0000000000000 --- a/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -add_mlir_translation_library(MLIRAMXToLLVMIRTranslation - AMXToLLVMIRTranslation.cpp - - DEPENDS - MLIRAMXConversionsIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRAMXDialect - MLIRLLVMDialect - MLIRSupport - MLIRTargetLLVMIRExport - ) diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt index 40df6e3f4b642..f030fa78942d5 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -1,7 +1,6 @@ add_subdirectory(ArmNeon) add_subdirectory(ArmSME) add_subdirectory(ArmSVE) -add_subdirectory(AMX) add_subdirectory(Builtin) add_subdirectory(GPU) add_subdirectory(LLVMIR) diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir index 8085f5f59fcaf..7e562b00a46a9 100644 --- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir @@ -1,17 +1,17 @@ // RUN: mlir-opt %s -convert-vector-to-llvm="enable-amx" | mlir-opt | FileCheck %s // CHECK-LABEL: muli( -// CHECK: amx.tilezero -// CHECK: amx.tileloadd64 -// CHECK: amx.tileloadd64 -// CHECK: amx.tdpbuud -// CHECK: amx.tilestored64 -// CHECK: amx.tdpbssd -// CHECK: amx.tilestored64 -// CHECK: amx.tdpbusd -// CHECK: amx.tilestored64 -// CHECK: amx.tdpbsud -// CHECK: amx.tilestored64 +// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tdpbuud.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tdpbssd.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal" +// CHECK: 
llvm.call_intrinsic "llvm.x86.tdpbusd.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tdpbsud.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal" func.func @muli(%arg0: memref, %arg1: memref) { %0 = arith.constant 0 : index %1 = amx.tile_zero : !amx.tile<16x64xi8> @@ -29,11 +29,11 @@ func.func @muli(%arg0: memref, %arg1: memref) { } // CHECK-LABEL: mulbf16( -// CHECK: amx.tilezero -// CHECK: amx.tileloadd64 -// CHECK: amx.tileloadd64 -// CHECK: amx.tdpbf16ps -// CHECK: amx.tilestored64 +// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tdpbf16ps.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal" func.func @mulbf16(%arg0: memref, %arg1: memref) { %0 = arith.constant 0 : index %1 = amx.tile_zero : !amx.tile<16x32xbf16> @@ -45,11 +45,11 @@ func.func @mulbf16(%arg0: memref, %arg1: memref) { } // CHECK-LABEL: mulfp16( -// CHECK: amx.tilezero -// CHECK: amx.tileloadd64 -// CHECK: amx.tileloadd64 -// CHECK: amx.tdpfp16ps -// CHECK: amx.tilestored64 +// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tdpfp16ps.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal" func.func @mulfp16(%arg0: memref, %arg1: memref) { %0 = arith.constant 0 : index %1 = amx.tile_zero : !amx.tile<16x32xf16> @@ -62,21 +62,21 @@ func.func @mulfp16(%arg0: memref, %arg1: memref) { // CHECK-LABEL: strides( // CHECK: %[[CST_64_1:.+]] = llvm.mlir.constant(64 : i64) : i64 -// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]] +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]] // CHECK: %[[CST_128_1:.+]] = llvm.mlir.constant(128 : i64) : i64 -// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]] +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]] // CHECK: llvm.mlir.constant(2 : i64) : i64 // CHECK: llvm.extractvalue %{{.+}}[4, 0] // CHECK: %[[STRIDE_1:.+]] = llvm.mul -// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]] +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]] // CHECK: %[[CST_64_2:.+]] = llvm.mlir.constant(64 : i64) : i64 -// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]] +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]] // CHECK: %[[CST_128_2:.+]] = llvm.mlir.constant(128 : i64) : i64 -// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]] +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]] // CHECK: llvm.mlir.constant(2 : i64) : i64 // CHECK: llvm.extractvalue %{{.+}}[4, 0] // CHECK: %[[STRIDE_2:.+]] = llvm.mul -// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]] +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]] func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) { %0 = arith.constant 0 : index %1 = amx.tile_load %arg0[%0, %0] : 
memref<16x32xbf16> into !amx.tile<16x32xbf16> diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir index 0281dfcd6ad69..094475040436d 100644 --- a/mlir/test/Target/LLVMIR/amx.mlir +++ b/mlir/test/Target/LLVMIR/amx.mlir @@ -1,13 +1,90 @@ -// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s +// RUN: mlir-opt %s --convert-vector-to-llvm="enable-amx" --convert-to-llvm -reconcile-unrealized-casts \ +// RUN: | mlir-translate --mlir-to-llvmir \ +// RUN: | FileCheck %s -// CHECK-LABEL: define void @target(ptr %0) -// CHECK: %[[c:.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 16) -// CHECK: call void @llvm.x86.tilestored64.internal(i16 16, i16 16, ptr %0, i64 32, x86_amx %[[c]] -llvm.func @target(%ptr: !llvm.ptr) { - %c = llvm.mlir.constant(16 : i16) : i16 - %s = llvm.mlir.constant(32 : i64) : i64 - %0 = "amx.tilezero"(%c, %c) : (i16, i16) -> !llvm.array<16 x vector<16xbf16>> - "amx.tilestored64"(%c, %c, %ptr, %s, %0) : (i16, i16, !llvm.ptr, i64, !llvm.array<16 x vector<16xbf16>>) -> () - llvm.return +// CHECK-LABEL: define void @amx_tile_zero +func.func @amx_tile_zero(%out: memref, %idx: index) +{ + // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64) + // CHECK: call void @llvm.x86.tilestored64.internal + %zero = amx.tile_zero : !amx.tile<16x16xf32> + amx.tile_store %out[%idx, %idx], %zero : memref, !amx.tile<16x16xf32> + return } +// CHECK-LABEL: define void @amx_tile_load_store +func.func @amx_tile_load_store(%base: memref, %out: memref, + %idx: index) +{ + // CHECK: call x86_amx @llvm.x86.tileloadd64.internal + // CHECK: call void @llvm.x86.tilestored64.internal + %val = amx.tile_load %base[%idx, %idx] : memref into !amx.tile<16x64xi8> + amx.tile_store %out[%idx, %idx], %val : memref, !amx.tile<16x64xi8> + return +} + +// CHECK-LABEL: define void @amx_tile_mulf_bf16 +func.func @amx_tile_mulf_bf16( + %matA: memref, %matB: memref, %idx: index, + %out: memref) +{ + // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64) + %acc = amx.tile_zero : !amx.tile<16x16xf32> + // CHECK-COUNT-2: call x86_amx @llvm.x86.tileloadd64.internal + %tA = amx.tile_load %matA[%idx, %idx] : memref into !amx.tile<16x32xbf16> + %tB = amx.tile_load %matB[%idx, %idx] : memref into !amx.tile<16x32xbf16> + // CHECK: call x86_amx @llvm.x86.tdpbf16ps.internal + %tRes = amx.tile_mulf %tA, %tB, %acc + : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> + // CHECK: call void @llvm.x86.tilestored64.internal + amx.tile_store %out[%idx, %idx], %tRes : memref, !amx.tile<16x16xf32> + return +} + +// CHECK-LABEL: define void @amx_tile_mulf_f16 +func.func @amx_tile_mulf_f16( + %matA: memref, %matB: memref, %idx: index, + %out: memref) +{ + // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64) + %acc = amx.tile_zero : !amx.tile<16x16xf32> + // CHECK-COUNT-2: call x86_amx @llvm.x86.tileloadd64.internal + %tA = amx.tile_load %matA[%idx, %idx] : memref into !amx.tile<16x32xf16> + %tB = amx.tile_load %matB[%idx, %idx] : memref into !amx.tile<16x32xf16> + // CHECK: call x86_amx @llvm.x86.tdpfp16ps.internal + %tRes = amx.tile_mulf %tA, %tB, %acc + : !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32> + // CHECK: call void @llvm.x86.tilestored64.internal + amx.tile_store %out[%idx, %idx], %tRes : memref, !amx.tile<16x16xf32> + return +} + +// CHECK-LABEL: define void @amx_tile_muli +func.func @amx_tile_muli(%matA: memref, %matB: memref, + %matC: memref, %idx: index, %out: memref) +{ + %c0 = arith.constant 0 : index + %c16 = 
arith.constant 16 : index + // CHECK-COUNT-3: call x86_amx @llvm.x86.tileloadd64.internal + %tA = amx.tile_load %matA[%idx, %idx] : memref into !amx.tile<16x64xi8> + %tB = amx.tile_load %matB[%idx, %idx] : memref into !amx.tile<16x64xi8> + %acc = amx.tile_load %matC[%idx, %idx] : memref into !amx.tile<16x16xi32> + // CHECK: call x86_amx @llvm.x86.tdpbuud.internal + // CHECK: call x86_amx @llvm.x86.tdpbssd.internal + // CHECK: call x86_amx @llvm.x86.tdpbusd.internal + // CHECK: call x86_amx @llvm.x86.tdpbsud.internal + %res = amx.tile_muli %tA zext, %tB zext, %acc + : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> + %res1 = amx.tile_muli %tA, %tB, %acc + : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> + %res2 = amx.tile_muli %tA zext, %tB, %acc + : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> + %res3 = amx.tile_muli %tA, %tB zext, %acc + : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> + // CHECK-COUNT-4: call void @llvm.x86.tilestored64.internal + amx.tile_store %out[%c0, %c0], %res : memref, !amx.tile<16x16xi32> + amx.tile_store %out[%c0, %c16], %res1 : memref, !amx.tile<16x16xi32> + amx.tile_store %out[%c16, %c0], %res2 : memref, !amx.tile<16x16xi32> + amx.tile_store %out[%c16, %c16], %res3 : memref, !amx.tile<16x16xi32> + return +}
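
// A hedged sketch (not part of this patch) of how a downstream pass might wire
// up the AMX lowering now that the dedicated AMX-to-LLVM-IR translation library
// is gone: the ops are rewritten to llvm.call_intrinsic during conversion, so
// only the conversion patterns and target configuration from
// mlir/Dialect/AMX/Transforms.h need to be registered. The wrapper function and
// surrounding boilerplate are illustrative assumptions; the two populate/
// configure entry points are the real API retained by this patch.
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static LogicalResult lowerAMXOps(Operation *root) {
  MLIRContext *ctx = root->getContext();
  LLVMTypeConverter converter(ctx);
  RewritePatternSet patterns(ctx);
  LLVMConversionTarget target(*ctx);
  // Adds the generic AMXIntrinsicOpConversion pattern and the
  // amx.tile -> !llvm.x86_amx type conversion.
  populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
  // Marks the whole AMX dialect illegal so every op must be rewritten.
  configureAMXLegalizeForExportTarget(target);
  return applyPartialConversion(root, target, std::move(patterns));
}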