diff --git a/build_tools/ci/cpu_comparison/run.py b/build_tools/ci/cpu_comparison/run.py index fdacf7cbc..5433a46af 100755 --- a/build_tools/ci/cpu_comparison/run.py +++ b/build_tools/ci/cpu_comparison/run.py @@ -783,6 +783,29 @@ def run(self, config): output_type=get_output_type(test_name), ) + # Large shape Matmul + Truncf + generate_matmul_test(test_name, template_name, 128, 128, 256, "bf16", "f32") + identity_mat = np.eye(128, dtype=np.float32) + ones = np.ones(128 * 128, dtype=np.float32).reshape([128, 128]) + lhs = ones * 101 + rhs = identity_mat * 3 + input_args = generate_inputs(test_name, output_dir, 1, {1: lhs, 2: rhs}) + aie_vs_baseline( + config, + test_name, + input_args, + ones * 302, # expected output + use_ukernel=False, + tile_pipeline="pack-peel", + lower_to_aie_pipeline="objectFifo", + function_name=None, + seed=1, + rtol=0, + atol=0, + n_repeats=1, + output_type=get_output_type(test_name), + ) + class SmokeSet(TestSet): def __init__(self): diff --git a/compiler/plugins/target/AMD-AIE/aievec/VectorToVectorConversions.cpp b/compiler/plugins/target/AMD-AIE/aievec/VectorToVectorConversions.cpp index fed15b694..c9c3918c9 100644 --- a/compiler/plugins/target/AMD-AIE/aievec/VectorToVectorConversions.cpp +++ b/compiler/plugins/target/AMD-AIE/aievec/VectorToVectorConversions.cpp @@ -26,6 +26,7 @@ #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -39,6 +40,73 @@ namespace mlir::iree_compiler::aievec { using namespace mlir; +struct CanonicalizeTrivialReadAccessSubviewOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, + PatternRewriter &rewriter) const override { + auto subViewOp = dyn_cast_if_present( 
readOp.getSource().getDefiningOp()); + if (!subViewOp) return failure(); + if (!llvm::all_of(readOp.getIndices(), [](Value val) { + IntegerAttr attr; + if (!matchPattern(val, m_Constant(&attr))) return false; + return attr.getInt() == 0; + })) + return failure(); + SmallVector newIndices; + for (OpFoldResult x : subViewOp.getMixedOffsets()) { + Value indexVal; + if (auto attr = dyn_cast(x)) { + indexVal = rewriter.create(readOp.getLoc(), + cast(attr)); + } else { + indexVal = cast(x); + } + newIndices.push_back(indexVal); + } + rewriter.replaceOpWithNewOp( + readOp, readOp.getType(), subViewOp.getSource(), newIndices, + readOp.getPadding(), readOp.getInBoundsValues()); + return success(); + } +}; + +struct CanonicalizeTrivialWriteAccessSubviewOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, + PatternRewriter &rewriter) const override { + auto subViewOp = dyn_cast_if_present( + writeOp.getSource().getDefiningOp()); + if (!subViewOp) return failure(); + if (!llvm::all_of(writeOp.getIndices(), [](Value val) { + IntegerAttr attr; + if (!matchPattern(val, m_Constant(&attr))) return false; + return attr.getInt() == 0; + })) + return failure(); + SmallVector newIndices; + for (OpFoldResult x : subViewOp.getMixedOffsets()) { + Value indexVal; + if (auto attr = dyn_cast(x)) { + indexVal = rewriter.create(writeOp.getLoc(), + cast(attr)); + } else { + indexVal = cast(x); + } + newIndices.push_back(indexVal); + } + rewriter.create( + writeOp.getLoc(), writeOp.getVector(), subViewOp.getSource(), + newIndices, writeOp.getInBoundsValues()); + rewriter.eraseOp(writeOp); + return success(); + } +}; + static bool isGemmBTransposedContractionOp(vector::ContractionOp op) { if (op.getKind() != vector::CombiningKind::ADD) return false; @@ -628,6 +696,12 @@ struct CanonicalizeVectorForAIEVecPass auto op = getOperation(); MLIRContext *context = &getContext(); + { + RewritePatternSet 
patterns(context); + patterns.add(context); + (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); + } { // These must run before 'populateVectorBroadcastLoweringPatterns' // so that broadcasts can be matched before conversion to insert. diff --git a/compiler/plugins/target/AMD-AIE/aievec/test/precanonicalization-aieml-llvmir.mlir b/compiler/plugins/target/AMD-AIE/aievec/test/precanonicalization-aieml-llvmir.mlir index bc80f51e6..408614e96 100644 --- a/compiler/plugins/target/AMD-AIE/aievec/test/precanonicalization-aieml-llvmir.mlir +++ b/compiler/plugins/target/AMD-AIE/aievec/test/precanonicalization-aieml-llvmir.mlir @@ -167,3 +167,40 @@ func.func @arith_truncf(%inp: vector<2x3xf32>) -> vector<2x3xbf16> { %0 = arith.truncf %inp : vector<2x3xf32> to vector<2x3xbf16> return %0 : vector<2x3xbf16> } + +// ----- + +// CHECK-LABEL: @trivial_read_access +// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8x4x8xbf16, strided<[256, 32, 8, 1]>>) +// CHECK: %[[COLLAPSE_SHAPE:.*]] = memref.collapse_shape %[[ARG0]] +// CHECK-SAME: into memref<1024xbf16, strided<[1]>> +// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSE_SHAPE]] +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[READ]] +// CHECK-SAME: vector<32xbf16> to vector<1x1x4x8xbf16> +// CHECK: return %[[SHAPE_CAST]] +func.func @trivial_read_access(%arg0: memref<4x8x4x8xbf16, strided<[256, 32, 8, 1]>>) -> vector<1x1x4x8xbf16> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : bf16 + %subview = memref.subview %arg0[2, 3, 0, 0] [1, 1, 4, 8] [1, 1, 1, 1] : memref<4x8x4x8xbf16, strided<[256, 32, 8, 1]>> to memref<1x1x4x8xbf16, strided<[256, 32, 8, 1], offset: 608>> + %read = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x1x4x8xbf16, strided<[256, 32, 8, 1], offset: 608>>, vector<1x1x4x8xbf16> + return %read : vector<1x1x4x8xbf16> +} + +// ----- + +// CHECK-LABEL: @trivial_write_access +// CHECK-SAME: (%[[ARG0:.*]]: memref<8x8x4x4xf32, 
strided<[128, 16, 4, 1]>>, +// CHECK-SAME: %[[ARG1:.*]]: vector<1x1x4x4xf32>) +// CHECK: %[[COLLAPSE_SHAPE:.*]] = memref.collapse_shape %[[ARG0]] +// CHECK-SAME: : memref<8x8x4x4xf32, strided<[128, 16, 4, 1]>> into memref<1024xf32, strided<[1]>> +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG1]] +// CHECK-SAME: : vector<1x1x4x4xf32> to vector<16xf32> +// CHECK: vector.transfer_write %[[SHAPE_CAST]], %[[COLLAPSE_SHAPE]] +// CHECK: return +func.func @trivial_write_access(%arg0: memref<8x8x4x4xf32, strided<[128, 16, 4, 1]>>, %arg1: vector<1x1x4x4xf32>) { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : bf16 + %subview = memref.subview %arg0[2, 3, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<8x8x4x4xf32, strided<[128, 16, 4, 1]>> to memref<1x1x4x4xf32, strided<[128, 16, 4, 1], offset: 304>> + vector.transfer_write %arg1, %subview[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<1x1x4x4xf32>, memref<1x1x4x4xf32, strided<[128, 16, 4, 1], offset: 304>> + return +} diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/tests/cdo/aie_cdo_elfs.bin b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/tests/cdo/aie_cdo_elfs.bin new file mode 100644 index 000000000..cba6b8778 Binary files /dev/null and b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/tests/cdo/aie_cdo_elfs.bin differ diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/tests/cdo/aie_cdo_init.bin b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/tests/cdo/aie_cdo_init.bin new file mode 100644 index 000000000..8beb02ddd Binary files /dev/null and b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/tests/cdo/aie_cdo_init.bin differ diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFunctionOutlining.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFunctionOutlining.cpp new file mode 100644 index 000000000..e764a86ad --- /dev/null +++ 
b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFunctionOutlining.cpp @@ -0,0 +1,112 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree-amd-aie/IR/AMDAIEOps.h" +#include "iree-amd-aie/Transforms/AMDAIEUtils.h" +#include "iree-amd-aie/Transforms/Passes.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +#define DEBUG_TYPE "iree-amdaie-function-outlining" + +namespace mlir::iree_compiler::AMDAIE { + +namespace { + +class AMDAIEFunctionOutliningPass + : public impl::AMDAIEFunctionOutliningBase { + public: + AMDAIEFunctionOutliningPass() = default; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override; +}; + +void AMDAIEFunctionOutliningPass::runOnOperation() { + ModuleOp moduleOp = getOperation(); + MLIRContext *context = &getContext(); + IRRewriter rewriter(context); + + auto outlinedToAFunction = [&](linalg::LinalgOp &computeOp) -> func::FuncOp { + // Form outlined FuncName. + std::string computeName = ""; + if (isMatmul(computeOp)) { + computeName = "_matmul"; + } else { + // TODO(avarma): Make this better/general. + computeName = "_elementwise"; + } + std::string outlinedFuncName = + computeOp->getName().stripDialect().str() + computeName + "_outlined"; + if (auto outlinedFuncOp = dyn_cast_if_present( + moduleOp.lookupSymbol(outlinedFuncName))) + return outlinedFuncOp; + + // Form outlined FunctionType. 
+ SmallVector inputTypes = llvm::map_to_vector( + computeOp.getDpsInputs(), [](Value v) { return v.getType(); }); + for (Value val : computeOp.getDpsInits()) + inputTypes.push_back(val.getType()); + auto outlinedFuncType = + FunctionType::get(rewriter.getContext(), inputTypes, {}); + + // Form outlined FuncSignature + rewriter.setInsertionPointToStart(moduleOp.getBody()); + auto outlinedFunc = rewriter.create( + moduleOp.getLoc(), outlinedFuncName, outlinedFuncType); + outlinedFunc.setPrivate(); + + // Create an entry func block and map the original operands of the compute + // op to the block arguments. + Block *outlinedFuncBody = outlinedFunc.addEntryBlock(); + rewriter.setInsertionPointToStart(outlinedFuncBody); + SmallVector outlinedFuncArgs = + llvm::map_to_vector(outlinedFunc.getArguments(), + [&](BlockArgument bbArg) { return bbArg; }); + unsigned bbArgIndex = 0; + IRMapping operandMap; + for (Value origOperand : computeOp.getDpsInputs()) { + operandMap.map(origOperand, outlinedFuncArgs[bbArgIndex++]); + } + for (Value origOperand : computeOp.getDpsInits()) { + operandMap.map(origOperand, outlinedFuncArgs[bbArgIndex++]); + } + + // Clone the compute op while mapping the operand to the function block + // arguments. + Operation *clonedComputeOp = rewriter.clone(*computeOp, operandMap); + + // Create the terminator op; the outlined function returns no values. 
+ rewriter.setInsertionPointToEnd(outlinedFuncBody); + rewriter.create(clonedComputeOp->getLoc(), ValueRange({})); + + return outlinedFunc; + }; + + moduleOp.walk([&](linalg::LinalgOp computeOp) { + if (isa(computeOp)) + return WalkResult::skip(); + func::FuncOp outlinedFuncOp = outlinedToAFunction(computeOp); + rewriter.setInsertionPoint(computeOp); + rewriter.create(computeOp.getLoc(), outlinedFuncOp, + computeOp->getOperands()); + rewriter.eraseOp(computeOp); + return WalkResult::advance(); + }); +} + +} // namespace + +std::unique_ptr createAMDAIEFunctionOutliningPass() { + return std::make_unique(); +} +} // namespace mlir::iree_compiler::AMDAIE diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.cpp index 0edeb3659..c70094767 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.cpp @@ -297,6 +297,7 @@ LogicalResult AIEDeviceBuilder::coreFuncCallOpToAIE( SymbolTable::setSymbolVisibility(newFnDecl, SymbolTable::Visibility::Private); newFnDecl->setAttr("llvm.bareptr", rewriter.getBoolAttr(true)); + fnDecl.getBody().cloneInto(&(newFnDecl.getBody()), mapper); mapper.map(fnDecl.getOperation(), newFnDecl.getOperation()); fnDecl = newFnDecl; } diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/CMakeLists.txt b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/CMakeLists.txt index a467ce00e..a1c084b2d 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/CMakeLists.txt +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/CMakeLists.txt @@ -70,6 +70,7 @@ iree_cc_library( "AMDAIEDmaToCircularDma.cpp" "AMDAIEDmaUtils.cpp" "AMDAIEFlattenLogicalObjectFifo.cpp" + "AMDAIEFunctionOutlining.cpp" "AMDAIEFuseConsumerIntoLoop.cpp" "AMDAIEFuseFillIntoForall.cpp" "AMDAIEFusePackIntoLoop.cpp" diff --git 
a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/PassDetail.h b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/PassDetail.h index 4cd5586f0..76ac15d9e 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/PassDetail.h +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/PassDetail.h @@ -47,6 +47,7 @@ namespace mlir::iree_compiler::AMDAIE { #define GEN_PASS_DEF_AMDAIEDMALOOPSUBSUMPTION #define GEN_PASS_DEF_AMDAIEDMATOCIRCULARDMA #define GEN_PASS_DEF_AMDAIEFLATTENLOGICALOBJECTFIFO +#define GEN_PASS_DEF_AMDAIEFUNCTIONOUTLINING #define GEN_PASS_DEF_AMDAIEFUSECONSUMERINTOLOOP #define GEN_PASS_DEF_AMDAIEFUSEFILLINTOFORALL #define GEN_PASS_DEF_AMDAIEFUSEPACKINTOLOOP diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp index 4bc7c8bc4..7c05b1f4e 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp @@ -284,9 +284,6 @@ void addPackPeelBasedPassPipeline(OpPassManager &funcPassManager, funcPassManager.addPass(createAMDAIELowerToUKernelsPass(options)); } - // Vectorization passes - appendVectorizationToPipeline(funcPassManager, enableVectorizationPasses); - // Comprehensive bufferization addAMDAIEBufferizePasses(funcPassManager, useTilePipeline); funcPassManager.addPass(createHoistStaticallyBoundAllocationsPass()); @@ -477,11 +474,6 @@ void addConvDecomposePassPipeline(OpPassManager &funcPassManager, LinalgFoldUnitExtentDimsPassOptions opts; opts.useRankReducingSlices = true; funcPassManager.addPass(mlir::createLinalgFoldUnitExtentDimsPass(opts)); - - // Vectorization passes - // FIXME(newling) https://github.com/nod-ai/iree-amd-aie/issues/820 - enableVectorizationPasses = false; - appendVectorizationToPipeline(funcPassManager, enableVectorizationPasses); funcPassManager.addPass(createCanonicalizerPass()); // Comprehensive bufferization @@ 
-520,10 +512,12 @@ void buildAMDAIETransformPassPipeline( modulePassManager.addPass(createLowerUKernelOpsToCallsPass()); if (useLowerToAIEPipeline == LowerToAIEPassPipeline::ObjectFifo) { addAMDAIEObjectFifoLoweringPasses(modulePassManager, enablePacketFlow, - useTilePipeline); + useTilePipeline, + enableVectorizationPasses); } else if (useLowerToAIEPipeline == LowerToAIEPassPipeline::AIR) { addMLIRAIRLoweringPasses(modulePassManager, device, useTilePipeline, - matmulElementwiseFusion); + matmulElementwiseFusion, + enableVectorizationPasses); } else { assert( false && @@ -539,10 +533,10 @@ void buildAMDAIETransformPassPipeline( }); } - void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager, bool enablePacketFlow, - TilePassPipeline useTilePipeline) { + TilePassPipeline useTilePipeline, + bool enableVectorizationPasses) { passManager.addPass(createEraseHALDescriptorTypeFromMemRefPass()); passManager.addPass(memref::createFoldMemRefAliasOpsPass()); @@ -565,6 +559,20 @@ void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager, passManager.addPass(createAMDAIENormalizeLoopBoundsPass()); passManager.addPass(createAMDAIEInsertCoresPass()); + + passManager.addPass(createAMDAIEFunctionOutliningPass()); + + { + // Vectorization passes + OpPassManager &funcPassManager = passManager.nest(); + // FIXME(newling) https://github.com/nod-ai/iree-amd-aie/issues/820 + enableVectorizationPasses = + (useTilePipeline == TilePassPipeline::ConvDecomposePipeline) + ? false + : enableVectorizationPasses; + appendVectorizationToPipeline(funcPassManager, enableVectorizationPasses); + } + passManager.addPass(createAMDAIELocalizeLogicalObjectFifoPass()); passManager.addPass(createCSEPass()); @@ -670,10 +678,21 @@ void addMLIRAIELoweringPasses(OpPassManager &pm) { // for details. 
void addMLIRAIRLoweringPasses(OpPassManager &passManager, AMDAIEDevice device, TilePassPipeline useTilePipeline, - bool matmulElementwiseFusion) { + bool matmulElementwiseFusion, + bool enableVectorizationPasses) { // Add passes for preparing for lowering to MLIR-AIR passManager.addPass(createEraseHALDescriptorTypeFromMemRefPass()); passManager.addPass(memref::createFoldMemRefAliasOpsPass()); + { + // Vectorization passes + OpPassManager &funcPassManager = passManager.nest(); + // FIXME(newling) https://github.com/nod-ai/iree-amd-aie/issues/820 + enableVectorizationPasses = + (useTilePipeline == TilePassPipeline::ConvDecomposePipeline) + ? false + : enableVectorizationPasses; + appendVectorizationToPipeline(funcPassManager, enableVectorizationPasses); + } passManager.addPass(createAMDAIEBridgeToAIRPass()); // Running canonicalization for all pipelines here results in failures. @@ -838,8 +857,6 @@ void addMLIRAIRLoweringPasses(OpPassManager &passManager, AMDAIEDevice device, addMLIRAIELoweringPasses(passManager); } - - // NOTE: this runs on the top-level program module containing all hal.executable // ops. void buildAMDAIELinkingPassPipeline(OpPassManager &passManager) { diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h index df670e19f..99cbbba2f 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h @@ -17,13 +17,15 @@ namespace mlir::iree_compiler::AMDAIE { /// Add passes to lower to AIE objectFifos. void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager, bool enablePacketFlow, - TilePassPipeline useTilePipeline); + TilePassPipeline useTilePipeline, + bool enableVectorizationPasses); /// Add passes to lower from MLIR-AIR through AIE. This is /// currently the default passes used for lowering after IREEs tiling. 
void addMLIRAIRLoweringPasses(OpPassManager &passManager, AMDAIEDevice device, TilePassPipeline useTilePipeline, - bool matmulElementwiseFusion); + bool matmulElementwiseFusion, + bool enableVectorizationPasses); /// Add lowering passes from MLIR-AIE. This is /// currently the default passes used for lowering from AIE dialect. @@ -161,6 +163,9 @@ std::unique_ptr createAMDAIEDmaToCircularDmaPass(); /// Create a pass to flatten the logical objectFifos. std::unique_ptr createAMDAIEFlattenLogicalObjectFifoPass(); +/// Create a pass for function outlining. +std::unique_ptr createAMDAIEFunctionOutliningPass(); + /// Create a pass to fuse the consumer op into the innermost last scf loop. std::unique_ptr createAMDAIEFuseConsumerIntoLoopPass( AMDAIEFuseConsumerIntoLoopOptions options = {}); diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td index 7c8364fed..cc2f3a34c 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td @@ -243,6 +243,12 @@ def AMDAIEFlattenLogicalObjectFifo : let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEFlattenLogicalObjectFifoPass()"; } +def AMDAIEFunctionOutlining : + Pass<"iree-amdaie-function-outlining", "ModuleOp"> { + let summary = "Function outlining"; + let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEFunctionOutliningPass()"; +} + def AMDAIEFuseConsumerIntoLoop : InterfacePass<"iree-amdaie-fuse-consumer-into-loop", "mlir::FunctionOpInterface"> { let summary = "Fuse the consumer operation into the innermost last scf loop."; diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/CMakeLists.txt b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/CMakeLists.txt index 570e66b83..31ed972f2 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/CMakeLists.txt +++ 
b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/CMakeLists.txt @@ -35,6 +35,7 @@ iree_lit_test_suite( "dma_loop_subsumption.mlir" "dma_to_circular_dma.mlir" "flatten_logical_objectfifo.mlir" + "function_outlining.mlir" "fuse_consumer_into_loop_scf_for.mlir" "fuse_consumer_into_loop_scf_forall.mlir" "fuse_fill_into_forall.mlir" diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/function_outlining.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/function_outlining.mlir new file mode 100644 index 000000000..c243f3dd4 --- /dev/null +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/function_outlining.mlir @@ -0,0 +1,91 @@ +// RUN: iree-opt --split-input-file --iree-amdaie-function-outlining --verify-diagnostics --split-input-file %s | FileCheck %s + +// CHECK-LABEL: func.func private @generic_matmul_outlined +// CHECK-SAME: (%[[LHS:.*]]: memref<1x1x4x8x4x8xbf16>, +// CHECK-SAME: %[[RHS:.*]]: memref<1x1x8x4x8x4xbf16>, +// CHECK-SAME: %[[OUT:.*]]: memref<1x1x8x8x4x4xf32>) { +// CHECK: linalg.generic +// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : +// CHECK-SAME: outs(%[[OUT]] : +// CHECK: return +// CHECK: } +// CHECK-LABEL: func.func @matmul_example +// CHECK-SAME: (%[[A:.*]]: memref<1x1x4x8x4x8xbf16>, +// CHECK-SAME: %[[B:.*]]: memref<1x1x8x4x8x4xbf16>, +// CHECK-SAME: %[[C:.*]]: memref<1x1x8x8x4x4xf32>) { +// CHECK: amdaie.core +// CHECK: func.call @generic_matmul_outlined(%[[A]], %[[B]], %[[C]]) +// CHECK-NOT: linalg.generic +// CHECK: amdaie.end +// CHECK: } +// CHECK: return +// CHECK: } +func.func @matmul_example(%A: memref<1x1x4x8x4x8xbf16>, %B: memref<1x1x8x4x8x4xbf16>, %C: memref<1x1x8x8x4x4xf32>) { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %tile = amdaie.tile(%c1, %c2) + %0 = amdaie.core(%tile, in : [], out : []) { + linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, + affine_map<(d0, d1, d2, d3, d4, d5, d6, 
d7, d8) -> (d2, d1, d4, d5, d8, d7)>, + affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)> + ], + iterator_types = ["parallel", "parallel", "reduction", + "parallel", "parallel", "reduction", + "parallel", "parallel", "reduction" + ] + } ins(%A, %B : memref<1x1x4x8x4x8xbf16>, memref<1x1x8x4x8x4xbf16>) + outs(%C : memref<1x1x8x8x4x4xf32>) { + ^bb0(%in: bf16, %in_17: bf16, %out: f32): + %1 = arith.extf %in : bf16 to f32 + %2 = arith.extf %in_17 : bf16 to f32 + %3 = arith.mulf %1, %2 : f32 + %4 = arith.addf %out, %3 : f32 + linalg.yield %4 : f32 + } + amdaie.end + } + return +} + +// ----- + +// CHECK-LABEL: func.func private @generic_elementwise_outlined +// CHECK-SAME: (%[[INPUT:.*]]: memref<1x1x8x8x4x4xf32>, +// CHECK-SAME: %[[OUTPUT:.*]]: memref<1x1x8x8x4x4xbf16>) { +// CHECK: linalg.generic +// CHECK-SAME: ins(%[[INPUT]] : +// CHECK-SAME: outs(%[[OUTPUT]] : +// CHECK: return +// CHECK: } +// CHECK-LABEL: func.func @elemwise_example +// CHECK-SAME: (%[[A:.*]]: memref<1x1x8x8x4x4xf32>, +// CHECK-SAME: %[[C:.*]]: memref<1x1x8x8x4x4xbf16>) { +// CHECK: amdaie.core +// CHECK: func.call @generic_elementwise_outlined(%[[A]], %[[C]]) +// CHECK-NOT: linalg.generic +// CHECK: amdaie.end +// CHECK: } +// CHECK: return +// CHECK: } +func.func @elemwise_example(%A: memref<1x1x8x8x4x4xf32>, %C: memref<1x1x8x8x4x4xbf16>) { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %tile = amdaie.tile(%c1, %c2) + %0 = amdaie.core(%tile, in : [], out : []) { + linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], + iterator_types = ["parallel", "parallel", "parallel", + "parallel", "parallel", "parallel" + ] + } ins(%A : memref<1x1x8x8x4x4xf32>) + outs(%C : memref<1x1x8x8x4x4xbf16>) { + ^bb0(%in: f32, %out: bf16): + %1 = arith.truncf %in : f32 to bf16 + linalg.yield %1 : bf16 + } + amdaie.end + } + return +}