[MLIR][linalg] DecomposeOuterUnitDimsPackOpPattern not checking trailing dimension for tiling #145861

Closed
@rYm-A

Description

While compiling an ML model with IREE, I found that iree-codegen-decompose-pack-unpack-ops generates an incorrect linalg.transpose op.

The linalg.pack op I'm attempting to lower, as it appears before iree-codegen-decompose-pack-unpack-ops runs, is:

%pack = linalg.pack %extracted_slice_1 
	outer_dims_perm = [1, 0, 2] 
	inner_dims_pos = [0, 2] 
	inner_tiles = [8, 1] 
	into %extracted_slice_2 
	{lowering_config = #iree_codegen.lowering_config<
		tile_sizes = [[1, 48, 64], [1, 1, 1]]>} 
	: tensor<8x1x1xf32> -> tensor<1x1x1x8x1xf32>

And the result is:

%18 = "tensor.empty"() : () -> tensor<1x1x1xf32>
%19 = "linalg.transpose"(%16, %18) <{permutation = array<i64: 0, 0, 2>}> ({
        ^bb0(%arg7: f32, %arg8: f32):
          "linalg.yield"(%arg7) : (f32) -> ()
        }) : (tensor<8x1x1xf32>, tensor<1x1x1xf32>) -> tensor<1x1x1xf32>

Note that the permutation array [0, 0, 2] is not a valid permutation (index 0 appears twice and index 1 is missing), which triggers an assertion failure:

Assertion `permutationMap.isPermutation() && "Invalid permutation vector"' failed.
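
To make the failure concrete, here is a minimal Python sketch of what "being a permutation" means for a vector like the one above. The helper name `is_permutation` is hypothetical and is not the MLIR implementation; it only models the property that the assertion checks.

```python
def is_permutation(perm):
    """Return True if perm contains each index 0..len(perm)-1 exactly once."""
    return sorted(perm) == list(range(len(perm)))


# Permutation emitted in the generated linalg.transpose above.
generated = [0, 0, 2]
# outer_dims_perm of the original linalg.pack, shown for comparison.
outer_dims_perm = [1, 0, 2]

print(is_permutation(generated))        # index 0 repeated, index 1 missing
print(is_permutation(outer_dims_perm))  # every index appears exactly once
```

The generated vector fails this property, which is consistent with the assertion firing when the transpose op's indexing maps are built.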

Checking this PR by @banach-space, I found that this op shouldn't be matched by GeneralizeOuterUnitDimsPackOpPattern: the inner_dims_pos of this linalg.pack op are not the trailing dimensions of the source, which that PR assumed.

The check in https://github.com/llvm/llvm-project/blame/237b8de2c0d9ee50c6a744e95c0706c8cdea70e1/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp#L1183 does not seem to catch this particular example: with the -1 at the end, it may be accepting the N + 1 trailing dims rather than the N trailing dims. Could you confirm whether this is intended?

Changing the check to return dimPos >= (srcRank - numTiles); makes the pattern fail to match on this example, and as a result this linalg.pack op is eventually lowered to:

%expanded = tensor.expand_shape %arg0 [[0, 1], [2], [3, 4]] output_shape [1, 8, 1, 1, 1] : tensor<8x1x1xf32> into tensor<1x8x1x1x1xf32>
%transposed = linalg.transpose ins(%expanded : tensor<1x8x1x1x1xf32>) outs(%arg1 : tensor<1x1x1x8x1xf32>) permutation = [2, 0, 3, 1, 4]  {lowering_config = #config}
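
The off-by-one discussed above can be checked with a small Python model. This is only a sketch under assumptions: the "current" predicate is inferred from the description in this issue (the proposed check plus a trailing -1), `num_tiles` is taken to be the number of entries in inner_dims_pos, and the function names are hypothetical rather than taken from Transforms.cpp.

```python
def accepts_current(inner_dims_pos, src_rank, num_tiles):
    # Assumed form of the existing check: dimPos >= srcRank - numTiles - 1.
    return all(pos >= src_rank - num_tiles - 1 for pos in inner_dims_pos)


def accepts_proposed(inner_dims_pos, src_rank, num_tiles):
    # Proposed form from this issue: dimPos >= srcRank - numTiles.
    return all(pos >= src_rank - num_tiles for pos in inner_dims_pos)


# Values from the linalg.pack op in this issue:
# source tensor<8x1x1xf32> has rank 3, inner_dims_pos = [0, 2] gives 2 tiles.
src_rank, inner_dims_pos = 3, [0, 2]
num_tiles = len(inner_dims_pos)

print(accepts_current(inner_dims_pos, src_rank, num_tiles))   # threshold is 0
print(accepts_proposed(inner_dims_pos, src_rank, num_tiles))  # threshold is 1
```

With these values the assumed current check has a threshold of 0, so dimension 0 passes even though it is not one of the two trailing dims; the proposed check raises the threshold to 1 and rejects it. An op that tiles only trailing dims, such as inner_dims_pos = [1, 2], passes both forms.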

How to reproduce

IREE v3.5.0.

iree-opt packOp.mlir \
--pass-pipeline="builtin.module(func.func(iree-codegen-decompose-pack-unpack-ops))"  \
--debug \
--mlir-disable-threading 

Trace:


Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0  libIREECompiler.so 0x00007fec4b0a85b8 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 40
1  libIREECompiler.so 0x00007fec4b0a635e llvm::sys::RunSignalHandlers() + 238
2  libIREECompiler.so 0x00007fec4b0a8cb6
3  libc.so.6          0x00007fec454e6520
4  libIREECompiler.so 0x00007fec4b0d50b3 mlir::AffineMap::getContext() const + 3
5  libIREECompiler.so 0x00007fec4b102bc4 mlir::AffineMapAttr::get(mlir::AffineMap) + 20
6  libIREECompiler.so 0x00007fec4b0ff979 mlir::Builder::getAffineMapArrayAttr(llvm::ArrayRef<mlir::AffineMap>) + 137
7  libIREECompiler.so 0x00007fec4f9b71fb mlir::linalg::TransposeOp::getIndexingMaps() + 475
8  libIREECompiler.so 0x00007fec4cb25b76
9  libIREECompiler.so 0x00007fec4fc51f67
10 libIREECompiler.so 0x00007fec4f99da41 mlir::linalg::LinalgOp::getIndexingMapsArray() + 17
11 libIREECompiler.so 0x00007fec4fb1cb0d
12 libIREECompiler.so 0x00007fec4ff8e3cd
13 libIREECompiler.so 0x00007fec4ff8b664 mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>) + 820
14 libIREECompiler.so 0x00007fec4ff729e7
15 libIREECompiler.so 0x00007fec4ff707bb mlir::applyPatternsGreedily(mlir::Region&, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig, bool*) + 1819
16 libIREECompiler.so 0x00007fec4dc12c0b
17 libIREECompiler.so 0x00007fec4dc1195c
18 libIREECompiler.so 0x00007fec4b31e5cb mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 635
19 libIREECompiler.so 0x00007fec4b31f1b9 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 329
20 libIREECompiler.so 0x00007fec4b321907 mlir::detail::OpToOpPassAdaptor::runOnOperationImpl(bool) + 2311
21 libIREECompiler.so 0x00007fec4b31ea8c mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 1852
22 libIREECompiler.so 0x00007fec4b3224db mlir::PassManager::run(mlir::Operation*) + 1531
23 libIREECompiler.so 0x00007fec4b3150be
24 libIREECompiler.so 0x00007fec4b314d3d
25 libIREECompiler.so 0x00007fec4b316c92 mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) + 818
26 libIREECompiler.so 0x00007fec4b30f362 mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) + 226
27 libIREECompiler.so 0x00007fec4b008e1d ireeOptRunMain + 2013
28 libc.so.6          0x00007fec454cdd90
29 libc.so.6          0x00007fec454cde40 __libc_start_main + 128
30 iree-opt           0x000000000020169e
