Skip to content

Commit

Permalink
Change calculation of reassociation indices in ConvertConvToChannels…
Browse files Browse the repository at this point in the history
…Last.cpp (#17668)

- Added `GreedyRewriteConfig` set to `kNoLimit` since the patterns were
failing to converge within the 10 iterations
- Changed the way re-association indices are calculated for
`GeneralizeOuterUnitDimsPackOps`. (There might be a helper function
somewhere for this, but I couldn't find one.)




#### Before (verification error)
```mlir
%30042 = "tensor.expand_shape"(%30041) 
	<{reassociation = [[0, 1], [2, 3], [4], [5]], 
		static_output_shape = array<i64: 1, 1, 3, 3, 320, 4>}> 
	: (tensor<3x3x320x4xi8>) -> tensor<1x1x3x3x320x4xi8>
```

#### After
```mlir
%30042 = "tensor.expand_shape"(%30041) 
	<{reassociation = [[0, 1, 2], [3], [4], [5]], 
		static_output_shape = array<i64: 1, 1, 3, 3, 320, 4>}> 
	: (tensor<3x3x320x4xi8>) -> tensor<1x1x3x3x320x4xi8>
```



#17643

---------

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
  • Loading branch information
IanWood1 authored Jun 13, 2024
1 parent 5f07787 commit 045bf32
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Preprocessing/Common/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -491,10 +492,10 @@ struct ConvertLinalgConvOp : OpInterfaceRewritePattern<linalg::LinalgOp> {
// map = [[0], [1, 2], [3], [4, 5]]
template <typename SetTy>
static SmallVector<ReassociationIndices>
getTilingReassociationMap(int64_t rank, SetTy innerDims) {
getTilingReassociationMap(const int64_t rank, SetTy innerDims) {
SmallVector<ReassociationIndices> map;
int64_t nTiled = 0;
for (int64_t i = 0, e = rank; i < e; i++) {
for (int64_t i = 0; i < rank; i++) {
if (innerDims.contains(i)) {
map.push_back({i + nTiled++, i + nTiled});
continue;
Expand Down Expand Up @@ -570,10 +571,18 @@ class GeneralizeOuterUnitDimsPackOp final
rewriter
.create<linalg::TransposeOp>(loc, packOp.getSource(), empty, perm)
.getResult()[0];

// Expand the unit dimensions for the result of the pack.
SmallVector<ReassociationIndices> reassocationIndices;
int64_t nTiled = 0;
for (int64_t srcIdx = 0; srcIdx < srcRank; srcIdx++) {
reassocationIndices.push_back({srcIdx + nTiled});
while (innerDims.contains(srcIdx + nTiled))
reassocationIndices.back().push_back(srcIdx + ++nTiled);
}

rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
packOp, destType, transposed,
getTilingReassociationMap(srcRank, innerDims));
packOp, destType, transposed, reassocationIndices);
return success();
}
};
Expand Down Expand Up @@ -638,11 +647,8 @@ class GeneralizeOuterUnitDimsUnPackOp final
auto collapse = rewriter.create<tensor::CollapseShapeOp>(
loc, collapsedType, unpackOp.getSource(),
getTilingReassociationMap(destType.getRank(), innerDims));
rewriter
.replaceOpWithNewOp<linalg::TransposeOp>(unpackOp, collapse,
unpackOp.getDest(),
invertPermutationVector(perm))
.getResult()[0];
rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
unpackOp, collapse, unpackOp.getDest(), invertPermutationVector(perm));
return success();
}
};
Expand Down Expand Up @@ -677,9 +683,12 @@ class ConvertConvToChannelsLastPass
// padding.
{
RewritePatternSet patterns(context);
GreedyRewriteConfig config;
config.maxIterations = GreedyRewriteConfig::kNoLimit;
linalg::populateDataLayoutPropagationPatterns(
patterns, [](Operation *op) { return true; });
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
if (failed(
applyPatternsAndFoldGreedily(op, std::move(patterns), config))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,22 @@ util.func @mmt_no_transpose(%arg0: tensor<2048x1280xf16>, %arg1: tensor<1280x128
// TILE16-LABEL: @mmt_no_transpose
// TILE16-NOT: linalg.generic
// TILE16: linalg.matmul_transpose_b


// -----

util.func @test_unit_dims_pack(%arg0: tensor<10x20x5xf32>) -> tensor<1x1x5x20x10xf32> {
%dst = tensor.empty() : tensor<1x1x5x20x10xf32>
%packed = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [20, 10]
into %dst : tensor<10x20x5xf32> -> tensor<1x1x5x20x10xf32>

util.return %packed : tensor<1x1x5x20x10xf32>
}

// CHECK-LABEL: @test_unit_dims_pack
// CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%[[ARG0:.+]] : tensor<10x20x5xf32>)
// CHECK-SAME: outs(%[[DST:.+]] : tensor<5x20x10xf32>) permutation = [2, 1, 0]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape
// CHECK-SAME: [0, 1, 2], [3], [4]
// CHECK-SAME: tensor<5x20x10xf32> into tensor<1x1x5x20x10xf32>
// CHECK: util.return %[[EXPANDED]] : tensor<1x1x5x20x10xf32>

0 comments on commit 045bf32

Please sign in to comment.