-
Notifications
You must be signed in to change notification settings - Fork 637
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Flow] Move elementwise op fusion and bubble up expand shapes pattern…
…s into their own pass. (#17068) Currently these were run within the `FusionOfTensorOps` pass as a single fixed point. This makes it hard to debug what exactly is happening with fusion. Having them at separate pass boundaries allows for better debuggability (especially on large models). This change should mostly be NFC and is the first step towards breaking up the `FusionOfTensorOps` pass.
- Loading branch information
1 parent
3677fbc
commit f32a87c
Showing
10 changed files
with
407 additions
and
221 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
98 changes: 98 additions & 0 deletions
98
compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
//===--- BubbleUpExpandShapes.cpp --- Pass to propagate expand shapes up --===// | ||
// | ||
// This pass propagates expand_shape operations up the program (and conversely) | ||
// sinks the collapse_shape operations down the program to get the elementwise | ||
// operations into higher dimensionality to get better fusion. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h" | ||
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h" | ||
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" | ||
#include "llvm/Support/Debug.h" | ||
#include "mlir/Dialect/Affine/IR/AffineOps.h" | ||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||
#include "mlir/Dialect/Tensor/Transforms/Transforms.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
#define DEBUG_TYPE "iree-flow-bubble-up-expand-shapes" | ||
|
||
namespace mlir::iree_compiler::IREE::Flow { | ||
|
||
#define GEN_PASS_DEF_BUBBLEUPEXPANDSHAPESPASS | ||
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc" | ||
|
||
namespace { | ||
|
||
class BubbleUpExpandShapesPass | ||
: public impl::BubbleUpExpandShapesPassBase<BubbleUpExpandShapesPass> { | ||
public: | ||
using Base::Base; | ||
|
||
void runOnOperation() override; | ||
}; | ||
|
||
} // namespace | ||
|
||
void BubbleUpExpandShapesPass::runOnOperation() { | ||
MLIRContext *context = &getContext(); | ||
|
||
RewritePatternSet bubbleExpandShapePatterns(context); | ||
linalg::ControlFusionFn bubbleUpExpansionControlFn = | ||
[](OpOperand *fusedOperand) { | ||
Operation *producer = fusedOperand->get().getDefiningOp(); | ||
Operation *consumer = fusedOperand->getOwner(); | ||
if (!isNonNullAndOutsideDispatch({producer, consumer})) { | ||
return false; | ||
} | ||
|
||
// Do not fuse by expand if consumer is dequant. | ||
if (isDequantizationLikeOp(consumer)) { | ||
return false; | ||
} | ||
|
||
// Do not fuse producer generic op if it has more than one user | ||
// or any reduction iterators. | ||
if (auto producerGenericOp = dyn_cast<linalg::GenericOp>(producer)) { | ||
return producerGenericOp->hasOneUse() && | ||
llvm::all_of(producerGenericOp.getIteratorTypesArray(), | ||
linalg::isParallelIterator); | ||
} | ||
|
||
// Do not fuse with any producer linalg named ops for now. | ||
if (isa<linalg::LinalgOp>(producer)) { | ||
return false; | ||
} | ||
|
||
// Do not fuse with consumer linalg named ops or reductions. | ||
if (auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer)) { | ||
return isa<linalg::GenericOp>(consumerLinalgOp) && | ||
llvm::all_of(consumerLinalgOp.getIteratorTypesArray(), | ||
linalg::isParallelIterator); | ||
} | ||
// Fuse in all other cases. | ||
return true; | ||
}; | ||
linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns, | ||
bubbleUpExpansionControlFn); | ||
// Add patterns to do some additional cleanup (on top of canonicalizations | ||
// that can be done later) of reshape ops. | ||
tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns); | ||
|
||
GreedyRewriteConfig rewriteConfig; | ||
rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit; | ||
if (failed(applyPatternsAndFoldGreedily(getOperation(), | ||
std::move(bubbleExpandShapePatterns), | ||
rewriteConfig))) { | ||
getOperation()->emitOpError("Failed to perform elementwise operations"); | ||
return signalPassFailure(); | ||
} | ||
} | ||
|
||
} // namespace mlir::iree_compiler::IREE::Flow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
87 changes: 87 additions & 0 deletions
87
compiler/src/iree/compiler/Dialect/Flow/Transforms/ElementwiseOpFusion.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
//===--- ElementwiseOpFusion.cpp --- Pass to fuse elementwise ops --------===// | ||
// | ||
// This pass applies the elementwise operation fusion transformation in Linalg | ||
// with an IREE-custom cost function. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h" | ||
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h" | ||
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" | ||
#include "llvm/Support/Debug.h" | ||
#include "mlir/Dialect/Affine/IR/AffineOps.h" | ||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
#define DEBUG_TYPE "iree-flow-elementwise-op-fusion" | ||
|
||
namespace mlir::iree_compiler::IREE::Flow { | ||
|
||
#define GEN_PASS_DEF_ELEMENTWISEOPFUSIONPASS | ||
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc" | ||
|
||
namespace { | ||
|
||
class ElementwiseOpFusionPass | ||
: public impl::ElementwiseOpFusionPassBase<ElementwiseOpFusionPass> { | ||
|
||
public: | ||
using Base::Base; | ||
|
||
void runOnOperation() override; | ||
}; | ||
|
||
} // namespace | ||
|
||
void ElementwiseOpFusionPass::runOnOperation() { | ||
MLIRContext *context = &getContext(); | ||
|
||
RewritePatternSet fusionPatterns(context); | ||
// Only fuse operations where all uses of the producer are generic | ||
// operations. If an operation is used in a named op, it will be computed | ||
// anyway, so the consumers can just use that value. | ||
linalg::ControlFusionFn fuseElementwiseOpsControlFn = | ||
[&](OpOperand *fusedOperand) { | ||
Operation *producer = fusedOperand->get().getDefiningOp(); | ||
Operation *consumer = fusedOperand->getOwner(); | ||
|
||
if (!isNonNullAndOutsideDispatch({producer, consumer})) { | ||
return false; | ||
} | ||
|
||
// Limit the number of operands. We have hard limit (32) of bindings | ||
// passing down to HAL. Set the number to be as same as the limit -- | ||
// IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT. | ||
constexpr int64_t kIreeMaxOperandCount = 32; | ||
DenseSet<Value> operands; | ||
operands.insert(producer->operand_begin(), producer->operand_end()); | ||
operands.insert(consumer->operand_begin(), | ||
std::next(consumer->operand_begin(), | ||
fusedOperand->getOperandNumber())); | ||
operands.insert(std::next(consumer->operand_begin(), | ||
fusedOperand->getOperandNumber() + 1), | ||
consumer->operand_end()); | ||
if (operands.size() >= kIreeMaxOperandCount) | ||
return false; | ||
|
||
return areFusableAsElementwiseOps(context, fusedOperand, | ||
fuseMultiReduction); | ||
}; | ||
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, | ||
fuseElementwiseOpsControlFn); | ||
GreedyRewriteConfig rewriteConfig; | ||
rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit; | ||
if (failed(applyPatternsAndFoldGreedily( | ||
getOperation(), std::move(fusionPatterns), rewriteConfig))) { | ||
getOperation()->emitOpError("Failed to perform elementwise operations"); | ||
return signalPassFailure(); | ||
} | ||
} | ||
|
||
} // namespace mlir::iree_compiler::IREE::Flow |
Oops, something went wrong.