Skip to content

Commit

Permalink
[Flow] Move elementwise op fusion and bubble up expand shapes pattern…
Browse files Browse the repository at this point in the history
…s into their own pass. (#17068)

Previously these were run within the `FusionOfTensorOps` pass as a single
fixed point. This made it hard to debug what exactly is happening with
fusion. Having them at separate pass boundaries allows for better
debuggability (especially on large models).

This change should mostly be NFC and is the first step towards breaking
up the `FusionOfTensorOps` pass.
  • Loading branch information
MaheshRavishankar authored Apr 18, 2024
1 parent 3677fbc commit f32a87c
Show file tree
Hide file tree
Showing 10 changed files with 407 additions and 221 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ iree_compiler_cc_library(
name = "Transforms",
srcs = [
"AnnotateDispatches.cpp",
"BubbleUpExpandShapes.cpp",
"CaptureDynamicDims.cpp",
"CleanupTensorShapes.cpp",
"CloneProducersIntoDispatchRegions.cpp",
Expand All @@ -41,12 +42,14 @@ iree_compiler_cc_library(
"DeduplicateExecutables.cpp",
"DispatchWithTransformDialect.cpp",
"DumpDispatchGraph.cpp",
"ElementwiseOpFusion.cpp",
"ExportBenchmarkFuncs.cpp",
"FoldUnitExtentDims.cpp",
"FormDispatchRegions.cpp",
"FormDispatchWorkgroups.cpp",
"FormScalarDispatches.cpp",
"FusionOfTensorOps.cpp",
"FusionUtils.cpp",
"InitializeEmptyTensors.cpp",
"InjectDispatchTracing.cpp",
"InjectTensorTracing.cpp",
Expand All @@ -65,6 +68,7 @@ iree_compiler_cc_library(
hdrs = [
"ConvertRegionToWorkgroups.h",
"FormDispatchRegions.h",
"FusionUtils.h",
"Passes.h",
"Passes.h.inc",
"RegionOpUtils.h",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

//===--- BubbleUpExpandShapes.cpp --- Pass to propagate expand shapes up --===//
//
// This pass propagates expand_shape operations up the program (and,
// conversely, sinks collapse_shape operations down the program) so that
// elementwise operations work on higher-dimensional shapes and fuse better.
//
//===----------------------------------------------------------------------===//

#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-flow-bubble-up-expand-shapes"

namespace mlir::iree_compiler::IREE::Flow {

#define GEN_PASS_DEF_BUBBLEUPEXPANDSHAPESPASS
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"

namespace {

/// Pass wrapper for bubbling up expand shapes. It derives from the
/// tablegen-generated `BubbleUpExpandShapesPassBase` and inherits that base's
/// constructors; the actual work lives in `runOnOperation()`, which is defined
/// out of line.
struct BubbleUpExpandShapesPass
    : impl::BubbleUpExpandShapesPassBase<BubbleUpExpandShapesPass> {
  // Re-export the constructors provided by the generated base class.
  using Base::Base;

  /// Entry point invoked by the pass manager on the operation being processed.
  void runOnOperation() override;
};

} // namespace

/// Runs the reshape-by-expansion folding patterns from Linalg, guarded by an
/// IREE-specific control function, together with the `tensor.empty` folding
/// patterns, using the greedy rewrite driver with no iteration limit.
///
/// Emits an error on the root operation and signals pass failure if the
/// greedy driver reports failure.
void BubbleUpExpandShapesPass::runOnOperation() {
  MLIRContext *context = &getContext();

  RewritePatternSet bubbleExpandShapePatterns(context);
  // Decides, per (producer, consumer) edge, whether the reshape may be
  // propagated across it. Returns true to allow the rewrite.
  linalg::ControlFusionFn bubbleUpExpansionControlFn =
      [](OpOperand *fusedOperand) {
        Operation *producer = fusedOperand->get().getDefiningOp();
        Operation *consumer = fusedOperand->getOwner();
        // Both ends must exist (the operand may be a block argument, in which
        // case there is no defining op) and pass the outside-dispatch check.
        if (!isNonNullAndOutsideDispatch({producer, consumer})) {
          return false;
        }

        // Do not fuse by expand if consumer is dequant.
        if (isDequantizationLikeOp(consumer)) {
          return false;
        }

        // Do not fuse producer generic op if it has more than one user
        // or any reduction iterators.
        if (auto producerGenericOp = dyn_cast<linalg::GenericOp>(producer)) {
          return producerGenericOp->hasOneUse() &&
                 llvm::all_of(producerGenericOp.getIteratorTypesArray(),
                              linalg::isParallelIterator);
        }

        // Do not fuse with any producer linalg named ops for now. Generic ops
        // were already handled above, so any LinalgOp reaching this point is
        // not a `linalg.generic`.
        if (isa<linalg::LinalgOp>(producer)) {
          return false;
        }

        // Do not fuse with consumer linalg named ops or reductions.
        if (auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer)) {
          return isa<linalg::GenericOp>(consumerLinalgOp) &&
                 llvm::all_of(consumerLinalgOp.getIteratorTypesArray(),
                              linalg::isParallelIterator);
        }
        // Fuse in all other cases.
        return true;
      };
  linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns,
                                                    bubbleUpExpansionControlFn);
  // Add patterns to do some additional cleanup (on top of canonicalizations
  // that can be done later) of reshape ops.
  tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);

  // Let the greedy driver iterate until it stops on its own.
  GreedyRewriteConfig rewriteConfig;
  rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
  if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                          std::move(bubbleExpandShapePatterns),
                                          rewriteConfig))) {
    // Report the failure against this pass; the previous text was copied from
    // the elementwise fusion pass and pointed at the wrong transformation.
    getOperation()->emitOpError("Failed to bubble up expand shapes");
    return signalPassFailure();
  }
}

} // namespace mlir::iree_compiler::IREE::Flow
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@ iree_cc_library(
HDRS
"ConvertRegionToWorkgroups.h"
"FormDispatchRegions.h"
"FusionUtils.h"
"Passes.h"
"Passes.h.inc"
"RegionOpUtils.h"
SRCS
"AnnotateDispatches.cpp"
"BubbleUpExpandShapes.cpp"
"CaptureDynamicDims.cpp"
"CleanupTensorShapes.cpp"
"CloneProducersIntoDispatchRegions.cpp"
Expand All @@ -40,12 +42,14 @@ iree_cc_library(
"DeduplicateExecutables.cpp"
"DispatchWithTransformDialect.cpp"
"DumpDispatchGraph.cpp"
"ElementwiseOpFusion.cpp"
"ExportBenchmarkFuncs.cpp"
"FoldUnitExtentDims.cpp"
"FormDispatchRegions.cpp"
"FormDispatchWorkgroups.cpp"
"FormScalarDispatches.cpp"
"FusionOfTensorOps.cpp"
"FusionUtils.cpp"
"InitializeEmptyTensors.cpp"
"InjectDispatchTracing.cpp"
"InjectTensorTracing.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

//===--- ElementwiseOpFusion.cpp --- Pass to fuse elementwise ops --------===//
//
// This pass applies the elementwise operation fusion transformation in Linalg
// with an IREE-custom cost function.
//
//===----------------------------------------------------------------------===//

#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-flow-elementwise-op-fusion"

namespace mlir::iree_compiler::IREE::Flow {

#define GEN_PASS_DEF_ELEMENTWISEOPFUSIONPASS
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"

namespace {

/// Pass wrapper for elementwise op fusion. It derives from the
/// tablegen-generated `ElementwiseOpFusionPassBase` and inherits that base's
/// constructors; the actual work lives in `runOnOperation()`, which is defined
/// out of line.
struct ElementwiseOpFusionPass
    : impl::ElementwiseOpFusionPassBase<ElementwiseOpFusionPass> {
  // Re-export the constructors provided by the generated base class.
  using Base::Base;

  /// Entry point invoked by the pass manager on the operation being processed.
  void runOnOperation() override;
};

} // namespace

void ElementwiseOpFusionPass::runOnOperation() {
MLIRContext *context = &getContext();

RewritePatternSet fusionPatterns(context);
// Only fuse operations where all uses of the producer are generic
// operations. If an operation is used in a named op, it will be computed
// anyway, so the consumers can just use that value.
linalg::ControlFusionFn fuseElementwiseOpsControlFn =
[&](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
Operation *consumer = fusedOperand->getOwner();

if (!isNonNullAndOutsideDispatch({producer, consumer})) {
return false;
}

// Limit the number of operands. We have hard limit (32) of bindings
// passing down to HAL. Set the number to be as same as the limit --
// IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT.
constexpr int64_t kIreeMaxOperandCount = 32;
DenseSet<Value> operands;
operands.insert(producer->operand_begin(), producer->operand_end());
operands.insert(consumer->operand_begin(),
std::next(consumer->operand_begin(),
fusedOperand->getOperandNumber()));
operands.insert(std::next(consumer->operand_begin(),
fusedOperand->getOperandNumber() + 1),
consumer->operand_end());
if (operands.size() >= kIreeMaxOperandCount)
return false;

return areFusableAsElementwiseOps(context, fusedOperand,
fuseMultiReduction);
};
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
fuseElementwiseOpsControlFn);
GreedyRewriteConfig rewriteConfig;
rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
if (failed(applyPatternsAndFoldGreedily(
getOperation(), std::move(fusionPatterns), rewriteConfig))) {
getOperation()->emitOpError("Failed to perform elementwise operations");
return signalPassFailure();
}
}

} // namespace mlir::iree_compiler::IREE::Flow
Loading

0 comments on commit f32a87c

Please sign in to comment.