Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit bf8a729

Browse files
author
Ivan Butygin
authored
[MLIR] Fusing optimizations (#200)
* PostPlierToLinalgPass pass * refactor out populate_common_opts_patterns * propagate if const values * simplify expand dims * plier inliner interface * simplify if and select * fixes * fixes * EnforceShapeOp * copy removal pass and parallel loop fusion * fixes * more SelectOp folding * Add non-recursive CSE * more efficient reshape * run few rounds ParallelOp fusing
1 parent 4c147f6 commit bf8a729

File tree

12 files changed

+555
-59
lines changed

12 files changed

+555
-59
lines changed

mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp

Lines changed: 149 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "plier/rewrites/call_lowering.hpp"
2828
#include "plier/rewrites/canonicalize_reductions.hpp"
2929
#include "plier/rewrites/cast_lowering.hpp"
30+
#include "plier/rewrites/common_opts.hpp"
3031
#include "plier/rewrites/cse.hpp"
3132
#include "plier/rewrites/promote_to_parallel.hpp"
3233
#include "plier/rewrites/type_conversion.hpp"
@@ -722,6 +723,103 @@ struct BinopRewriter : public mlir::OpRewritePattern<plier::BinOp>
722723
resolver_t resolver;
723724
};
724725

726+
// Rewrite pattern for linalg.generic ops that read an input through a map
// containing constant-zero results. Where the matching output dimension has
// static extent 1, the constant 0 in the input indexing map is replaced by the
// loop dimension expression for that position.
//
// The pattern only applies when the op:
//   * has tensor semantics,
//   * has exactly one input and one output,
//   * uses only "parallel" iterators,
//   * has an identity output indexing map, and
//   * has an input indexing map with as many results as there are loops.
//
// The op is updated in place (only its indexing_maps attribute changes).
struct SimplifyExpandDims : public mlir::OpRewritePattern<mlir::linalg::GenericOp>
727+
{
728+
using mlir::OpRewritePattern<mlir::linalg::GenericOp>::OpRewritePattern;
729+
730+
// Returns success only if at least one indexing-map result was rewritten;
// every early exit below reports failure without touching the op.
mlir::LogicalResult matchAndRewrite(
731+
mlir::linalg::GenericOp op, mlir::PatternRewriter &rewriter) const override
732+
{
733+
if (!op.hasTensorSemantics())
734+
{
735+
return mlir::failure();
736+
}
737+
if (op.getNumInputs() != 1 || op.getNumOutputs() != 1)
738+
{
739+
return mlir::failure();
740+
}
741+
742+
auto context = op.getContext();
743+
auto parallel_attr = mlir::StringAttr::get(context, "parallel");
744+
// Reject ops with any non-"parallel" iterator.
if (llvm::any_of(op.iterator_types(), [&](auto attr) { return attr != parallel_attr; }))
745+
{
746+
return mlir::failure();
747+
}
748+
749+
auto maps = op.indexing_maps();
750+
// One input + one output was checked above, so there are two maps:
// maps[0] for the input and maps[1] for the output.
assert(maps.size() == 2);
751+
auto out_map = maps[1].cast<mlir::AffineMapAttr>().getValue();
752+
if (!out_map.isIdentity())
753+
{
754+
return mlir::failure();
755+
}
756+
auto in_map = maps[0].cast<mlir::AffineMapAttr>().getValue();
757+
auto num_dims = op.getNumLoops();
758+
// The loop below indexes in_map results by loop position, so the counts
// must agree.
if (in_map.getNumResults() != num_dims)
759+
{
760+
return mlir::failure();
761+
}
762+
763+
bool changed = false;
764+
auto out_shape = op.getOutput(0).getType().cast<mlir::RankedTensorType>().getShape();
765+
// New input-map results; entries are either the original expression or
// the dimension expression d_i.
llvm::SmallVector<mlir::AffineExpr> exprs(num_dims);
766+
for (unsigned i = 0; i < num_dims; ++i)
767+
{
768+
auto prev_expr = in_map.getResult(i);
769+
// Convertible when the output extent at i is statically 1 and the input
// map result at i is the constant 0.
// NOTE(review): presumably safe because d_i can only take the value 0
// when the extent is 1 - confirm.
bool can_convert = [&]()
770+
{
771+
if (out_shape[i] == 1)
772+
{
773+
auto const_expr = prev_expr.dyn_cast<mlir::AffineConstantExpr>();
774+
if (const_expr && const_expr.getValue() == 0)
775+
{
776+
return true;
777+
}
778+
}
779+
return false;
780+
}();
781+
if (can_convert)
782+
{
783+
changed = true;
784+
exprs[i] = mlir::getAffineDimExpr(i, context);
785+
}
786+
else
787+
{
788+
exprs[i] = prev_expr;
789+
}
790+
}
791+
792+
if (changed)
793+
{
794+
// Rebuilt input map followed by the original (identity) output map.
const mlir::Attribute new_maps[] = {
795+
mlir::AffineMapAttr::get(mlir::AffineMap::get(num_dims, 0, exprs, context)),
796+
maps[1]
797+
};
798+
auto new_maps_attr = mlir::ArrayAttr::get(context, new_maps);
799+
// Route the attribute change through the rewriter so the in-place
// modification is reported to the pattern driver.
rewriter.updateRootInPlace(op, [&]()
800+
{
801+
op.indexing_mapsAttr(new_maps_attr);
802+
});
803+
}
804+
805+
return mlir::success(changed);
806+
}
807+
};
808+
809+
// Rewrite pattern that replaces every plier::EnforceShapeOp with a
// tensor::CastOp from the op's `value` operand to the op's result type.
// Only the result type and `value` are forwarded to the new op.
struct LowerEnforceShape : public mlir::OpRewritePattern<plier::EnforceShapeOp>
810+
{
811+
using mlir::OpRewritePattern<plier::EnforceShapeOp>::OpRewritePattern;
812+
813+
// Always matches and always reports success.
mlir::LogicalResult matchAndRewrite(
814+
plier::EnforceShapeOp op, mlir::PatternRewriter &rewriter) const override
815+
{
816+
auto type = op.getType();
817+
auto src = op.value();
818+
rewriter.replaceOpWithNewOp<mlir::tensor::CastOp>(op, type, src);
819+
return mlir::success();
820+
}
821+
};
822+
725823
void PlierToLinalgPass::runOnOperation()
726824
{
727825
auto context = &getContext();
@@ -801,38 +899,61 @@ void LowerLinalgPass::runOnOperation()
801899
(void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
802900
}
803901

804-
struct CommonOptPass :
805-
public mlir::PassWrapper<CommonOptPass, mlir::OperationPass<mlir::ModuleOp>>
902+
struct PostPlierToLinalgPass :
903+
public mlir::PassWrapper<PostPlierToLinalgPass, mlir::OperationPass<mlir::ModuleOp>>
806904
{
807-
virtual void getDependentDialects(
808-
mlir::DialectRegistry &registry) const override
809-
{
810-
registry.insert<mlir::StandardOpsDialect>();
811-
registry.insert<mlir::linalg::LinalgDialect>();
812-
registry.insert<mlir::scf::SCFDialect>();
813-
registry.insert<mlir::AffineDialect>();
814-
}
905+
void runOnOperation() override;
906+
};
907+
908+
// Builds a pattern list from the shared common-opts patterns plus
// SimplifyExpandDims and applies it greedily to the operation this pass
// runs on.
void PostPlierToLinalgPass::runOnOperation()
909+
{
910+
mlir::OwningRewritePatternList patterns;
911+
912+
auto& context = getContext();
913+
plier::populate_common_opts_patterns(context, patterns);
815914

915+
patterns.insert<
916+
SimplifyExpandDims
917+
>(&getContext());
918+
919+
// The result of the greedy driver is deliberately discarded.
(void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
920+
}
921+
922+
struct TensorFusionPass :
923+
public mlir::PassWrapper<TensorFusionPass, mlir::OperationPass<mlir::ModuleOp>>
924+
{
816925
void runOnOperation() override;
817926
};
818927

819-
void CommonOptPass::runOnOperation()
928+
void TensorFusionPass::runOnOperation()
820929
{
821930
mlir::OwningRewritePatternList patterns;
822931

823932
auto& context = getContext();
824-
for (auto *op : context.getRegisteredOperations())
825-
{
826-
op->getCanonicalizationPatterns(patterns, &context);
827-
}
933+
plier::populate_common_opts_patterns(context, patterns);
828934

829935
patterns.insert<
830-
// LoopInvariantCodeMotion, TODO
831-
plier::ForceInline,
832-
plier::CSERewrite<mlir::FuncOp>
833-
>(&context);
936+
SimplifyExpandDims,
937+
LowerEnforceShape
938+
>(&getContext());
939+
940+
mlir::populateLinalgTensorOpsFusionPatterns(&context, patterns);
834941

835-
plier::populate_index_propagate_patterns(context, patterns);
942+
(void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
943+
}
944+
945+
struct CommonOptPass :
946+
public mlir::PassWrapper<CommonOptPass, mlir::OperationPass<mlir::ModuleOp>>
947+
{
948+
void runOnOperation() override;
949+
};
950+
951+
// Applies only the shared common-opts patterns, greedily, to the operation
// this pass runs on.
void CommonOptPass::runOnOperation()
952+
{
953+
mlir::OwningRewritePatternList patterns;
954+
955+
auto& context = getContext();
956+
plier::populate_common_opts_patterns(context, patterns);
836957

837958
// The result of the greedy driver is deliberately discarded.
(void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
838959
}
@@ -897,15 +1018,6 @@ void RetainArgsPass::runOnFunction()
8971018
struct PostLinalgOptPass :
8981019
public mlir::PassWrapper<PostLinalgOptPass, mlir::OperationPass<mlir::ModuleOp>>
8991020
{
900-
virtual void getDependentDialects(
901-
mlir::DialectRegistry &registry) const override
902-
{
903-
registry.insert<mlir::StandardOpsDialect>();
904-
registry.insert<mlir::linalg::LinalgDialect>();
905-
registry.insert<mlir::scf::SCFDialect>();
906-
registry.insert<mlir::AffineDialect>();
907-
}
908-
9091021
void runOnOperation() override;
9101022
};
9111023

@@ -914,35 +1026,26 @@ void PostLinalgOptPass::runOnOperation()
9141026
mlir::OwningRewritePatternList patterns;
9151027

9161028
auto& context = getContext();
917-
for (auto *op : context.getRegisteredOperations())
918-
{
919-
op->getCanonicalizationPatterns(patterns, &context);
920-
}
1029+
plier::populate_common_opts_patterns(context, patterns);
9211030

9221031
patterns.insert<
9231032
plier::CanonicalizeReduction,
924-
// LoopInvariantCodeMotion, TODO
925-
plier::PromoteToParallel,
926-
plier::CmpLoopBoundsSimplify,
927-
plier::CSERewrite<mlir::FuncOp>
1033+
plier::PromoteToParallel
9281034
>(&context);
9291035

930-
plier::populate_index_propagate_patterns(context, patterns);
931-
9321036
(void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
9331037
}
9341038

9351039
void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm)
9361040
{
9371041
pm.addPass(std::make_unique<PlierToLinalgPass>());
938-
pm.addPass(std::make_unique<CommonOptPass>());
1042+
pm.addPass(std::make_unique<PostPlierToLinalgPass>());
9391043
pm.addPass(mlir::createSymbolDCEPass());
9401044
}
9411045

9421046
void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)
9431047
{
944-
pm.addPass(mlir::createLinalgFusionOfTensorOpsPass());
945-
pm.addPass(std::make_unique<CommonOptPass>());
1048+
pm.addPass(std::make_unique<TensorFusionPass>());
9461049

9471050
pm.addPass(mlir::createTensorConstantBufferizePass());
9481051
pm.addNestedPass<mlir::FuncOp>(mlir::createSCFBufferizePass());
@@ -958,10 +1061,14 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)
9581061

9591062
pm.addNestedPass<mlir::FuncOp>(std::make_unique<RetainArgsPass>());
9601063
pm.addNestedPass<mlir::FuncOp>(mlir::createBufferDeallocationPass());
1064+
pm.addPass(mlir::createCopyRemovalPass());
9611065

9621066
pm.addPass(std::make_unique<LowerLinalgPass>());
1067+
pm.addPass(mlir::createParallelLoopFusionPass());
9631068
pm.addPass(std::make_unique<PostLinalgOptPass>());
9641069
pm.addPass(mlir::createSymbolDCEPass());
1070+
pm.addPass(mlir::createParallelLoopFusionPass()); // TODO: make this rewrite and add to PostLinalgOptPass
1071+
pm.addPass(std::make_unique<PostLinalgOptPass>());
9651072
}
9661073
}
9671074

mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,9 @@ mlir::Value expand_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value
452452
{
453453
assert(dim < shape.size());
454454
shape[dim] = 1;
455-
mlir::Type casted_type = mlir::RankedTensorType::get(shape, src_type.getElementType());
456-
auto casted = builder.create<mlir::tensor::CastOp>(loc, casted_type, src).getResult();
455+
// mlir::Type casted_type = mlir::RankedTensorType::get(shape, src_type.getElementType());
456+
// auto casted = builder.create<mlir::tensor::CastOp>(loc, casted_type, src).getResult();
457+
auto casted = src; // TODO
457458
auto init = builder.create<mlir::linalg::InitTensorOp>(loc, new_shape, src_type.getElementType()).getResult();
458459
llvm::SmallVector<mlir::AffineExpr> exprs(num_dims);
459460
for (unsigned i = 0; i < num_dims; ++i)
@@ -503,6 +504,7 @@ mlir::Value expand_dims(mlir::OpBuilder& builder, mlir::Location loc, mlir::Valu
503504
{
504505
current = expand_dim(builder, loc, val, current, i, target_shape);
505506
}
507+
current = builder.create<plier::EnforceShapeOp>(loc, current, target_shape);
506508
return current;
507509
}
508510

mlir-compiler/plier/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@ set(SOURCES_LIST
1818
src/rewrites/call_lowering.cpp
1919
src/rewrites/canonicalize_reductions.cpp
2020
src/rewrites/cast_lowering.cpp
21+
src/rewrites/common_opts.cpp
2122
src/rewrites/cse.cpp
2223
src/rewrites/force_inline.cpp
2324
src/rewrites/index_type_propagation.cpp
25+
include/plier/rewrites/if_rewrites.cpp
2426
src/rewrites/loop_rewrites.cpp
2527
src/rewrites/promote_to_parallel.cpp
2628
src/rewrites/type_conversion.cpp
@@ -39,9 +41,11 @@ set(HEADERS_LIST
3941
include/plier/rewrites/call_lowering.hpp
4042
include/plier/rewrites/canonicalize_reductions.hpp
4143
include/plier/rewrites/cast_lowering.hpp
44+
include/plier/rewrites/common_opts.hpp
4245
include/plier/rewrites/cse.hpp
4346
include/plier/rewrites/force_inline.hpp
4447
include/plier/rewrites/index_type_propagation.hpp
48+
include/plier/rewrites/if_rewrites.hpp
4549
include/plier/rewrites/loop_rewrites.hpp
4650
include/plier/rewrites/promote_to_parallel.hpp
4751
include/plier/rewrites/type_conversion.hpp

mlir-compiler/plier/include/plier/PlierOps.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,20 @@ def GetattrOp : Plier_Op<"getattr", [NoSideEffect]> {
214214
];
215215
}
216216

217+
// plier.enforce_shape: takes a ranked tensor `value` plus a variadic list of
// index-typed `sizes`, and yields a ranked tensor result.
// NOTE(review): declared without any traits (e.g. no NoSideEffect) - confirm
// this is intentional.
def EnforceShapeOp : Plier_Op<"enforce_shape"> {
218+
let arguments = (ins AnyRankedTensor:$value,
219+
Variadic<Index>:$sizes);
220+
221+
let results = (outs AnyRankedTensor:$result);
222+
223+
// Custom builder taking the source value and the shape values.
let builders = [
224+
OpBuilderDAG<(ins "::mlir::Value":$value, "::mlir::ValueRange":$shape)>
225+
];
226+
227+
// Both a folder and canonicalization patterns are provided for this op.
let hasFolder = 1;
228+
let hasCanonicalizer = 1;
229+
}
230+
217231
def RetainOp : Plier_Op<"retain"> {
218232
let arguments = (ins AnyMemRef:$value);
219233

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#pragma once
2+
3+
namespace mlir
4+
{
5+
class OwningRewritePatternList;
6+
class MLIRContext;
7+
}
8+
9+
namespace plier
10+
{
11+
// Adds the project's shared optimization rewrite patterns to `patterns`,
// using `context` to construct them. The definition is not in this header;
// NOTE(review): the exact pattern set should be confirmed against the
// implementation file.
void populate_common_opts_patterns(mlir::MLIRContext& context, mlir::OwningRewritePatternList& patterns);
12+
}

mlir-compiler/plier/include/plier/rewrites/cse.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ namespace plier
77
{
88
namespace detail
99
{
10-
mlir::LogicalResult applyCSE(mlir::Region& region, mlir::PatternRewriter& rewriter);
10+
// Runs common-subexpression elimination over `region`, performing all IR
// changes through `rewriter`. `recursive` receives the `Recursive` template
// argument of CSERewrite; presumably it selects whether nested regions are
// processed as well - confirm against the definition.
mlir::LogicalResult applyCSE(mlir::Region& region, mlir::PatternRewriter& rewriter, bool recursive);
1111
}
1212

13-
template<typename Op>
13+
template<typename Op, bool Recursive>
1414
struct CSERewrite : public mlir::OpRewritePattern<Op>
1515
{
1616
CSERewrite(mlir::MLIRContext *context):
@@ -19,7 +19,7 @@ struct CSERewrite : public mlir::OpRewritePattern<Op>
1919
mlir::LogicalResult matchAndRewrite(
2020
Op op, mlir::PatternRewriter &rewriter) const override
2121
{
22-
return ::plier::detail::applyCSE(op.getRegion(), rewriter);
22+
return ::plier::detail::applyCSE(op.getRegion(), rewriter, Recursive);
2323
}
2424
};
2525
}

0 commit comments

Comments
 (0)