From c84cca03034b61c97a081ebce47cf31a06255356 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 8 Oct 2024 17:41:35 -0700 Subject: [PATCH] [Convolution] Packing + objectFifo, initial support (#789) This PR switches all numerical convolution tests to use the objectFifo pipeline. With respect to the new tiling strategy: 1) A single **column** is currently used. Targeting multiple columns results in ` error: 'aie.memtile_dma' op could not find and assign a valid BD id`. This will be investigated as follow-up work: https://github.com/nod-ai/iree-amd-aie/issues/821 2) There is no longer interleaving of compute and L2->L1 data movement, which means https://github.com/nod-ai/iree-amd-aie/issues/619 becomes low priority / obsolete. 3) L3->L2 and L2->L3 still use padding, but L2->L1 and L1->L2 use packing. 4) Channel-first convolution is completely unsupported; we expect high-level transforms to convert to channel-last before reaching our backend. 5) Vectorization is not currently enabled due to issues with alignment. See follow-up task https://github.com/nod-ai/iree-amd-aie/issues/820. This is functionally ok for now, as peano can scalarize code for all data types. --- build_tools/ci/cpu_comparison/run.py | 4 +- .../iree-amd-aie/Test/samples/CMakeLists.txt | 2 +- ...e.mlir => conv2d_nhwc_objectfifo_e2e.mlir} | 2 +- .../Transforms/AMDAIEPackAndTranspose.cpp | 17 ++- .../Transforms/KernelDispatch.cpp | 122 +++++++++++++----- .../iree-amd-aie/Transforms/Passes.cpp | 60 ++++----- .../test/lowering_strategy_conv.mlir | 47 +++---- 7 files changed, 145 insertions(+), 109 deletions(-) rename compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/{conv2d_nhwc_air_e2e.mlir => conv2d_nhwc_objectfifo_e2e.mlir} (94%) diff --git a/build_tools/ci/cpu_comparison/run.py b/build_tools/ci/cpu_comparison/run.py index 4be273f44..dc63990dc 100755 --- a/build_tools/ci/cpu_comparison/run.py +++ b/build_tools/ci/cpu_comparison/run.py @@ -641,7 +641,7 @@ def run(self, config): config, test_name, tile_pipeline="conv-decompose", - lower_to_aie_pipeline="air", + lower_to_aie_pipeline="objectFifo", n_repeats=n_conv_repeats, ) @@ -661,7 +661,7 @@ def run(self, config): config, test_files_dir / f"{name}.mlir", tile_pipeline="conv-decompose", - lower_to_aie_pipeline="air", + lower_to_aie_pipeline="objectFifo", n_repeats=n_conv_repeats, ) diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/CMakeLists.txt b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/CMakeLists.txt index 618409664..0cf6bc133 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/CMakeLists.txt +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/CMakeLists.txt @@ -8,7 +8,7 @@ iree_lit_test_suite( NAME lit SRCS - "conv2d_nhwc_air_e2e.mlir" + "conv2d_nhwc_objectfifo_e2e.mlir" "matmul_elementwise_pack_peel_air_e2e.mlir" "matmul_pack_peel_air_e2e.mlir" "matmul_pack_peel_objectfifo.mlir" diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/conv2d_nhwc_air_e2e.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/conv2d_nhwc_objectfifo_e2e.mlir similarity index 94% rename from compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/conv2d_nhwc_air_e2e.mlir rename to compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/conv2d_nhwc_objectfifo_e2e.mlir index 2b005150a..171667038 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/conv2d_nhwc_air_e2e.mlir +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/conv2d_nhwc_objectfifo_e2e.mlir @@ 
-1,4 +1,4 @@ -// RUN: iree-compile --iree-hal-target-backends=amd-aie --compile-to=executable-sources %s | iree-opt --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-translate-target-executable-variants{target=amd-aie})))" --iree-amdaie-tile-pipeline=conv-decompose --iree-amdaie-lower-to-aie-pipeline=air --split-input-file | FileCheck %s +// RUN: iree-compile --iree-hal-target-backends=amd-aie --compile-to=executable-sources %s | iree-opt --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-translate-target-executable-variants{target=amd-aie})))" --iree-amdaie-tile-pipeline=conv-decompose --iree-amdaie-lower-to-aie-pipeline=objectFifo --split-input-file | FileCheck %s func.func @conv_2d_nhwc_hwcf(%arg0: tensor<2x14x14x32xi32>, %arg1: tensor<3x3x32x64xi32>) -> tensor<2x12x12x64xi32> { %cst = arith.constant 0 : i32 diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEPackAndTranspose.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEPackAndTranspose.cpp index 8f846110d..62544391e 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEPackAndTranspose.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEPackAndTranspose.cpp @@ -20,7 +20,6 @@ namespace { static FailureOr applyPackOnLinalgOp( RewriterBase &rewriter, linalg::LinalgOp op, SmallVector packedSizes) { - // Fail on mismatched number of pack sizes. if (packedSizes.size() != op.getNumLoops()) { op->emitOpError( "requires number of packed sizes match the number of loops (") @@ -29,12 +28,14 @@ static FailureOr applyPackOnLinalgOp( } rewriter.setInsertionPoint(op); - FailureOr packResult = + FailureOr maybePackResult = linalg::pack(rewriter, op, packedSizes); - if (failed(packResult)) { + if (failed(maybePackResult)) { op->emitOpError("failed to pack the operation"); return failure(); } + + linalg::PackResult packResult = maybePackResult.value(); return packResult; } @@ -60,7 +61,8 @@ void AMDAIEPackAndTransposePass::runOnOperation() { // Find the linalg op for packing, currently only consider contraction ops linalg::LinalgOp linalgOp; funcOp->walk([&](linalg::LinalgOp op) { - if (linalg::isaContractionOpInterface(op)) { + if (linalg::isaContractionOpInterface(op) || + isa(op.getOperation())) { linalgOp = op; return WalkResult::interrupt(); } @@ -75,6 +77,7 @@ void AMDAIEPackAndTransposePass::runOnOperation() { // Step 1. Before packing the operation, we will prefetch the lowering and // packing config. auto config = getLoweringConfig(linalgOp); + auto packingConfig = getPackingConfig(linalgOp); if (!config || !packingConfig) { @@ -87,6 +90,12 @@ void AMDAIEPackAndTransposePass::runOnOperation() { // Extract packing config from the `linalgOp`. 
PackingConfigPackingLevelAttr packCfg = packingConfig.getPackingConfigVals(packLevel); + + if (!packCfg) { + funcOp->emitOpError("failed to get pack config for pack level ") + << packLevel; + return signalPassFailure(); + } SmallVector packedSizes = getAsIndexOpFoldResult(context, packCfg.getPackedSizes()); diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp index ca657b796..398298120 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp @@ -472,54 +472,83 @@ static LogicalResult setRootConfigForPadPackPipeline( static LogicalResult setRootConfigForConvDecomposePipeline( mlir::FunctionOpInterface entryPointFn, linalg::LinalgOp linalgOp) { + MLIRContext *context = entryPointFn.getContext(); + FailureOr> maybeInstructionSize = getMatmulInstructionSize(linalgOp); int64_t OW = 4; int64_t OC = 4; int64_t IC = 8; if (succeeded(maybeInstructionSize)) { - auto instructionSize = maybeInstructionSize.value(); - OW = instructionSize[0]; - OC = instructionSize[1]; - IC = instructionSize[2]; + auto [m, n, k] = maybeInstructionSize.value(); + OW = m; + OC = n; + IC = k; } + SmallVector transposePackIndices{0, 1, 2}; + SmallVector unpackEmpty{false, false, true}; + + // Convolution type specific vectors: + SmallVector> innerPerm; + SmallVector> outerPerm; SmallVector tileSizeLevel0; SmallVector tileSizeLevel1; SmallVector tileSizeLevel2; - // Note: some of the tiling dimensions are hardcoded for now. - if (isa(linalgOp) || - isa(linalgOp)) { - // conv_2d_nhwc_hwcf tiling dims: [N, OH, OW, OC, KH, KW, IC]. - tileSizeLevel0 = {0, 4, OW, OC, 0, 0, 0}; + SmallVector packingSizes; + + // [N, OH, OW, OC, KH, KW, IC]. + if (isa(linalgOp) || + isa(linalgOp)) { + // The goal is to pack the input image and kernel as follows, when moving + // from L2 to L1 (example where there are 32 input channels): + // Image: memref<1x3x6x32xbf16> -> memref<1x3x4x6x8xbf16> + // Kernel: memref<3x3x32x4xbf16> -> memref<3x3x4x1x8x4xbf16> + innerPerm = {{}, {{1, 0}}, {}}; + outerPerm = {{0, 1, 3, 2}, {}, {0, 1, 2, 3}}; + packingSizes = {0, 0, 0, OC, 0, 0, IC}; + // Target one column of 4 cores, each core processing a different + // output image row. TODO(newling) use 4x4 array. + // https://github.com/nod-ai/iree-amd-aie/issues/821 + tileSizeLevel0 = {1, 4, OW, OC, 0, 0, 0}; tileSizeLevel1 = {1, 1, OW, OC, 0, 0, 0}; - tileSizeLevel2 = {0, 0, 0, 0, 1, 1, IC}; - } else if (isa(linalgOp)) { - // conv_2d_nchw_fchw tiling dims: [N, OC, OH, OW, IC, KH, KW]. - tileSizeLevel0 = {0, OC, 4, OW, 0, 0, 0}; - tileSizeLevel1 = {1, OC, 1, OW, 0, 0, 0}; - tileSizeLevel2 = {0, 0, 0, 0, IC, 1, 1}; - } else if (isa(linalgOp)) { - // Notes: + // scf.for tiling of KH, KW, and (packed) IC dimensions: + tileSizeLevel2 = {0, 0, 0, 0, 1, 1, 1, 0, 0}; + } + + // [N, OC, OH, OW, IC, KH, KW] + else if (isa(linalgOp)) { + // The matmul reduction dimension is the input channel (IC) dimension. + // For Conv2DNhwcHwcfOp, this dimension is already the inner-most dimension + // of the input image, and the penultimate dimension of the kernel -- + // exactly what we want. For Conv2DNchwFchwOp, can the tensor dimensions be + // permuted in DMA to get them in the correct positions? For the image + // tensor, only if H*W is a nice power of 2 (DMA constraint). 
For kernels, + it requires h*w to be a nice power of 2 -- unlikely, we typically have + h=w=3. The dimension permutations will therefore often need to + be done on the core. We leave this for future work; the expectation for + now is that models have been transformed at a high level to avoid + channel-first convolutions. + return linalgOp.emitError( + "Only channel-last convolution supported currently."); + } + + // [N, OH, OW, C, KW, HW] + else if (isa(linalgOp)) { + // Notes // ===== - // - // An inherent property of depthwise convolutions is that they cannot be - // expressed in terms of matmuls, unlike the above (dense) conv-2ds. The - // tile sizes we choose below are therefore not constrained by the AIE - // matmul instructions. + // A property of depthwise convolution is that it can't be expressed in + // terms of matmul, unlike the above (dense) conv-2ds. The tile sizes we + // choose below are therefore not constrained by AIE matmul instructions. // // The logic is currently fragile, and there are no guardrails: there are // no checks that the data tiles are not too large, or that the input // dimensions are perfectly tiled by the hard-coded tile dimensions below. // These will be done as a follow-up task. - // - // - // Below we target a 4x4 array of AIE cores. auto getElementType = [](Value v) { return cast(v.getType()).getElementType(); }; const uint16_t OW_0 = 4; - const uint16_t OH_0 = 4; const uint16_t OH_1 = 1; auto operandType = getElementType(linalgOp->getOperand(0)); @@ -530,8 +559,8 @@ static LogicalResult setRootConfigForConvDecomposePipeline( OC_0 = maybeMacNumElements.value(); } // If the operand type has fewer than 32-bits, we really should be able to - // get a mac-width for it Bail because we didn't, and there's probably just - // something missing in the table. + // get a mac-width for it. Bail because we didn't; there's probably just + // something missing in a table. else if (operandType.getIntOrFloatBitWidth() < 32) { return linalgOp.emitError( "has an operand type with fewer than 32-bits, but no mac-width " @@ -539,17 +568,40 @@ static LogicalResult setRootConfigForConvDecomposePipeline( const uint16_t OC_1 = OC_0 / 4; - - // depthwise_conv2d_nhwc_hwc tiling dims: - // [N, OH, OW, OC, KH,KW] - tileSizeLevel0 = {1, OH_0, OW_0, OC_0, 0, 0}; + packingSizes = {0, 0, 0, OC_1, 0, 0}; + innerPerm = {{}, {}, {}}; + outerPerm = {{0, 1, 2, 3}, {0, 1, 2}, {0, 1, 2, 3}}; + // Target one column of 4 cores, each core processing a different + // output image row. TODO(newling) use 4x4 array. + // https://github.com/nod-ai/iree-amd-aie/issues/821 + tileSizeLevel0 = {1, 4 * OH_1, OW_0, OC_1, 0, 0}; tileSizeLevel1 = {1, OH_1, OW_0, OC_1, 0, 0}; - tileSizeLevel2 = {0, 0, 0, 0, 1, 1}; - } else { - assert(false && "Support must be added for this convolution op"); + tileSizeLevel2 = {0, 0, 0, 0, 1, 1, 0}; } + + else { + return linalgOp.emitError( + "unrecognised convolution op, cannot set packing config. 
"); + } + + assert(!innerPerm.empty() && !outerPerm.empty() && !packingSizes.empty() && + !tileSizeLevel0.empty() && !tileSizeLevel1.empty() && + "not all vectors for initializing config are non-empty"); + + auto packingConfigLevel1Attr = getPackingConfigPackingLevelAttr( + context, packingSizes, transposePackIndices, unpackEmpty, innerPerm, + outerPerm); + SmallVector packingConfigLevelsVal{ + packingConfigLevel1Attr}; + + auto packingConfigLevels = + PackingConfigPackingLevelsAttr::get(context, packingConfigLevelsVal); + auto config = PackingConfigAttr::get(context, packingConfigLevels); + setPackingConfig(linalgOp, config); + TileSizesListType tileSizes = {tileSizeLevel0, tileSizeLevel1, tileSizeLevel2}; + return setOpConfigAndEntryPointFnTranslation( entryPointFn, linalgOp, tileSizes, IREE::Codegen::DispatchLoweringPassPipeline::Custom); diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp index 75a40545c..b5044f592 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp @@ -408,16 +408,20 @@ void addConvDecomposePassPipeline(OpPassManager &funcPassManager, TilingConfig &tilingConfig, bool enableVectorizationPasses, TilePassPipeline useTilePipeline) { + auto addCleanups = [&]() { + funcPassManager.addPass(createAMDAIECleanupPass()); + funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createCSEPass()); + }; + // First level tiling using scf.forall { AMDAIETileAndFuseOptions tileFuseOptions; tileFuseOptions.tilingLevel = 0; tileFuseOptions.useSCFFor = false; funcPassManager.addPass(createAMDAIETileAndFusePass(tileFuseOptions)); + addCleanups(); } - funcPassManager.addPass(createAMDAIECleanupPass()); - funcPassManager.addPass(createCanonicalizerPass()); - funcPassManager.addPass(createCSEPass()); // Pad the linalg operation { @@ -441,67 +445,50 @@ void addConvDecomposePassPipeline(OpPassManager &funcPassManager, tileFuseOptions.tilingLevel = 1; tileFuseOptions.useSCFFor = false; funcPassManager.addPass(createAMDAIETileAndFusePass(tileFuseOptions)); + addCleanups(); } - funcPassManager.addPass(createAMDAIECleanupPass()); - funcPassManager.addPass(createCanonicalizerPass()); - funcPassManager.addPass(createCSEPass()); // Fuse fill op into the inner forall loop funcPassManager.addPass(createAMDAIEFuseFillIntoForallPass()); - funcPassManager.addPass(createCanonicalizerPass()); - // Pad the linalg operation + // Pack the linalg operation { - AMDAIEPadOptions padOptions; - padOptions.paddingLevel = 1; - funcPassManager.addPass(createAMDAIEPadPass(padOptions)); + AMDAIEPackAndTransposeOptions packOptions; + packOptions.packLevel = 0; + funcPassManager.addPass(createAMDAIEPackAndTransposePass(packOptions)); } - // Only promote the result to local memory + // Promote the inputs and results to local memory { AMDAIEBufferizeToAllocationOptions bufferizeOptions; bufferizeOptions.memorySpace = 2; - bufferizeOptions.bufferizeOperand = BufferizeOperand::Output; + bufferizeOptions.bufferizeOperand = BufferizeOperand::InputOutput; funcPassManager.addPass( createAMDAIEBufferizeToAllocationPass(bufferizeOptions)); + addCleanups(); } - // Tile the reduction dimension using scf.for { AMDAIETileAndFuseOptions tileFuseOptions; tileFuseOptions.tilingLevel = 2; tileFuseOptions.useSCFFor = true; funcPassManager.addPass(createAMDAIETileAndFusePass(tileFuseOptions)); - } - 
funcPassManager.addPass(createAMDAIECleanupPass()); - funcPassManager.addPass(createCanonicalizerPass()); - funcPassManager.addPass(createCSEPass()); - - // Pad the linalg operation - { - AMDAIEPadOptions padOptions; - padOptions.paddingLevel = 2; - funcPassManager.addPass(createAMDAIEPadPass(padOptions)); - } - - // Promote the inputs to local memory - { - AMDAIEBufferizeToAllocationOptions bufferizeOptions; - bufferizeOptions.memorySpace = 2; - bufferizeOptions.bufferizeOperand = BufferizeOperand::Input; - funcPassManager.addPass( - createAMDAIEBufferizeToAllocationPass(bufferizeOptions)); + addCleanups(); } - // Decompose Conv2d ops to Conv1d ops - funcPassManager.addPass(createDecomposeConvolutionToLowerDimOpsPass()); + LinalgFoldUnitExtentDimsPassOptions opts; + opts.useRankReducingSlices = true; + funcPassManager.addPass(mlir::createLinalgFoldUnitExtentDimsPass(opts)); // Vectorization passes + // FIXME(newling) https://github.com/nod-ai/iree-amd-aie/issues/820 + enableVectorizationPasses = false; appendVectorizationToPipeline(funcPassManager, enableVectorizationPasses); funcPassManager.addPass(createCanonicalizerPass()); // Comprehensive bufferization addAMDAIEBufferizePasses(funcPassManager, useTilePipeline); + funcPassManager.addPass(createHoistStaticallyBoundAllocationsPass()); } void buildAMDAIETransformPassPipeline( @@ -557,6 +544,7 @@ void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager, bool enablePacketFlow) { passManager.addPass(createEraseHALDescriptorTypeFromMemRefPass()); passManager.addPass(memref::createFoldMemRefAliasOpsPass()); + passManager.addPass(createCanonicalizerPass()); passManager.addPass(createAMDAIEConvertToDmaPass()); passManager.addPass(createAMDAIENormalizeLoopBoundsPass()); @@ -582,6 +570,7 @@ void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager, passManager.addPass(createAMDAIEAssignLogicalObjectFifoDepthPass()); passManager.addPass(createAMDAIEAccessToAcquireReleasePass()); passManager.addPass(createAMDAIENoneAccessToTemporaryBufferPass()); + passManager.addPass( createAMDAIEAssignConnectionTypesPass({enablePacketFlow})); passManager.addPass(createCSEPass()); @@ -612,6 +601,7 @@ void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager, passManager.addPass(createCanonicalizerPass()); passManager.addPass(createAMDAIEObjFifoBufferizationPass()); + passManager.addPass(createAMDAIETemporaryAllocBufferizationPass()); passManager.addPass(createAMDAIEConnectionToFlowPass()); passManager.addPass(createAMDAIEAssignPacketIdsPass()); diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lowering_strategy_conv.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lowering_strategy_conv.mlir index ad7b127ec..d07c3d136 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lowering_strategy_conv.mlir +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lowering_strategy_conv.mlir @@ -1,32 +1,9 @@ // RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-amdaie-lowering-strategy{use-pass-pipeline=conv-decompose})' %s | FileCheck %s -// CHECK{LITERAL}: #config = #iree_codegen.lowering_config -#pipeline_layout = #hal.pipeline.layout, - , - -]> -builtin.module { - func.func @conv_2d_nchw_fchw_2x64x12x12x32x3x3_i32() { - %cst = arith.constant 0 : i32 - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan 
layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 32, 14, 14], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x32x14x14xi32> - %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [64, 32, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<64x32x3x3xi32> - %5 = tensor.empty() : tensor<2x64x12x12xi32> - %6 = linalg.fill ins(%cst : i32) outs(%5 : tensor<2x64x12x12xi32>) -> tensor<2x64x12x12xi32> - %7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<2x32x14x14xi32>, tensor<64x32x3x3xi32>) outs(%6 : tensor<2x64x12x12xi32>) -> tensor<2x64x12x12xi32> - // CHECK: linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, lowering_config = #config, strides = dense<1> : vector<2xi64>} - flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 64, 12, 12], strides = [1, 1, 1, 1] : tensor<2x64x12x12xi32> -> !flow.dispatch.tensor> - return - } -} -// ----- -// CHECK{LITERAL}: #iree_codegen.lowering_config +// CHECK{LITERAL}: #config = #iree_codegen.lowering_config +// CHECK{LITERAL}: #packingConfig = #amdaie.packing_config #pipeline_layout = #hal.pipeline.layout, , @@ -43,14 +20,17 @@ func.func @conv_static_dispatch_0_conv_2d_nhwc_hwcf_2x12x12x64x3x3x32_bf16xbf16x %5 = tensor.empty() : tensor<2x12x12x64xf32> %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x12x12x64xf32>) -> tensor<2x12x12x64xf32> %7 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<2x14x14x32xbf16>, tensor<3x3x32x64xbf16>) outs(%6 : tensor<2x12x12x64xf32>) -> tensor<2x12x12x64xf32> - // CHECK: linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #config, strides = dense<1> : vector<2xi64>} + // CHECK: linalg.conv_2d_nhwc_hwcf + // CHECK-SAME: lowering_config = #config, + // CHECK-SAME: packing_config = #packingConfig, flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 12, 12, 64], strides = [1, 1, 1, 1] : tensor<2x12x12x64xf32> -> !flow.dispatch.tensor> return } // ----- -// CHECK{LITERAL}: #config = #iree_codegen.lowering_config +// CHECK{LITERAL}: #config = #iree_codegen.lowering_config +// CHECK{LITERAL}: #packingConfig = #amdaie.packing_config #pipeline_layout = #hal.pipeline.layout, , @@ -67,15 +47,18 @@ func.func @conv_depthwise_channel_last_bf16(){ %5 = tensor.empty() : tensor<2x12x12x64xf32> %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x12x12x64xf32>) -> tensor<2x12x12x64xf32> %7 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<2x14x14x64xbf16>, tensor<3x3x64xbf16>) outs(%6 : tensor<2x12x12x64xf32>) -> tensor<2x12x12x64xf32> - // CHECK: linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, lowering_config = #config, strides = dense<1> : vector<2xi64>} + // CHECK: linalg.depthwise_conv_2d_nhwc_hwc + // CHECK-SAME: lowering_config = #config, + // CHECK-SAME: packing_config = #packingConfig, flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 12, 12, 64], strides = [1, 1, 1, 1] : tensor<2x12x12x64xf32> -> !flow.dispatch.tensor> return } // ----- -// Same test as above, but where the operand type is i8. 
In this case we expect OC tile size of 32 (not 16) at level 0, and 8 at levels 1 and 2. This is because of the instruction size of AIE. +// Same test as above, but where the operand type is i8. In this case we expect OC tile size 8 (not 4) at level 1. This is because of the instruction size of AIE. -// CHECK{LITERAL}: #config = #iree_codegen.lowering_config +// CHECK{LITERAL}: #config = #iree_codegen.lowering_config +// CHECK{LITERAL}: #packingConfig = #amdaie.packing_config #pipeline_layout = #hal.pipeline.layout, , @@ -92,7 +75,9 @@ func.func @conv_depthwise_channel_last_i8(){ %5 = tensor.empty() : tensor<2x12x12x64xi32> %6 = linalg.fill ins(%cst : i32) outs(%5 : tensor<2x12x12x64xi32>) -> tensor<2x12x12x64xi32> %7 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<2x14x14x64xi8>, tensor<3x3x64xi8>) outs(%6 : tensor<2x12x12x64xi32>) -> tensor<2x12x12x64xi32> - // CHECK: linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, lowering_config = #config, strides = dense<1> : vector<2xi64>} + // CHECK: linalg.depthwise_conv_2d_nhwc_hwc + // CHECK-SAME: lowering_config = #config, + // CHECK-SAME: packing_config = #packingConfig, flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 12, 12, 64], strides = [1, 1, 1, 1] : tensor<2x12x12x64xi32> -> !flow.dispatch.tensor> return }
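For reference, the L2->L1 packing that the new packing config requests for conv_2d_nhwc_hwcf can be sketched as standalone IR roughly as follows. This is only a sketch: it assumes the upstream tensor.pack op spelling, a hypothetical 1x3x6x32xbf16 input-image tile named %image, and the OC=4 / IC=8 sizes used in the KernelDispatch.cpp comment; the real ops are produced by AMDAIEPackAndTranspose from the packing_config attribute.

  %dest = tensor.empty() : tensor<1x3x4x6x8xbf16>
  // Split the 32 input channels into 4 x 8 and hoist the new outer channel
  // dimension ahead of the image width: 1x3x6x32 -> 1x3x4x6x8.
  %packed = tensor.pack %image outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [3] inner_tiles = [8] into %dest : tensor<1x3x6x32xbf16> -> tensor<1x3x4x6x8xbf16>

This corresponds to packingSizes = {0, 0, 0, OC, 0, 0, IC} with outerPerm = {0, 1, 3, 2} for the image operand, i.e. the memref<1x3x6x32xbf16> -> memref<1x3x4x6x8xbf16> example given in the KernelDispatch.cpp comment above.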