 #include "plier/rewrites/call_lowering.hpp"
 #include "plier/rewrites/canonicalize_reductions.hpp"
 #include "plier/rewrites/cast_lowering.hpp"
+#include "plier/rewrites/common_opts.hpp"
 #include "plier/rewrites/cse.hpp"
 #include "plier/rewrites/promote_to_parallel.hpp"
 #include "plier/rewrites/type_conversion.hpp"
@@ -722,6 +723,103 @@ struct BinopRewriter : public mlir::OpRewritePattern<plier::BinOp>
     resolver_t resolver;
 };
 
+struct SimplifyExpandDims : public mlir::OpRewritePattern<mlir::linalg::GenericOp>
+{
+    using mlir::OpRewritePattern<mlir::linalg::GenericOp>::OpRewritePattern;
+
+    mlir::LogicalResult matchAndRewrite(
+        mlir::linalg::GenericOp op, mlir::PatternRewriter &rewriter) const override
+    {
+        if (!op.hasTensorSemantics())
+        {
+            return mlir::failure();
+        }
+        if (op.getNumInputs() != 1 || op.getNumOutputs() != 1)
+        {
+            return mlir::failure();
+        }
+
+        auto context = op.getContext();
+        auto parallel_attr = mlir::StringAttr::get(context, "parallel");
+        if (llvm::any_of(op.iterator_types(), [&](auto attr) { return attr != parallel_attr; }))
+        {
+            return mlir::failure();
+        }
+
+        auto maps = op.indexing_maps();
+        assert(maps.size() == 2);
+        auto out_map = maps[1].cast<mlir::AffineMapAttr>().getValue();
+        if (!out_map.isIdentity())
+        {
+            return mlir::failure();
+        }
+        auto in_map = maps[0].cast<mlir::AffineMapAttr>().getValue();
+        auto num_dims = op.getNumLoops();
+        if (in_map.getNumResults() != num_dims)
+        {
+            return mlir::failure();
+        }
+
+        bool changed = false;
+        auto out_shape = op.getOutput(0).getType().cast<mlir::RankedTensorType>().getShape();
+        llvm::SmallVector<mlir::AffineExpr> exprs(num_dims);
+        for (unsigned i = 0; i < num_dims; ++i)
+        {
+            auto prev_expr = in_map.getResult(i);
+            bool can_convert = [&]()
+            {
+                if (out_shape[i] == 1)
+                {
+                    auto const_expr = prev_expr.dyn_cast<mlir::AffineConstantExpr>();
+                    if (const_expr && const_expr.getValue() == 0)
+                    {
+                        return true;
+                    }
+                }
+                return false;
+            }();
+            if (can_convert)
+            {
+                changed = true;
+                exprs[i] = mlir::getAffineDimExpr(i, context);
+            }
+            else
+            {
+                exprs[i] = prev_expr;
+            }
+        }
+
+        if (changed)
+        {
+            const mlir::Attribute new_maps[] = {
+                mlir::AffineMapAttr::get(mlir::AffineMap::get(num_dims, 0, exprs, context)),
+                maps[1]
+            };
+            auto new_maps_attr = mlir::ArrayAttr::get(context, new_maps);
+            rewriter.updateRootInPlace(op, [&]()
+            {
+                op.indexing_mapsAttr(new_maps_attr);
+            });
+        }
+
+        return mlir::success(changed);
+    }
+};
+
+struct LowerEnforceShape : public mlir::OpRewritePattern<plier::EnforceShapeOp>
+{
+    using mlir::OpRewritePattern<plier::EnforceShapeOp>::OpRewritePattern;
+
+    mlir::LogicalResult matchAndRewrite(
+        plier::EnforceShapeOp op, mlir::PatternRewriter &rewriter) const override
+    {
+        auto type = op.getType();
+        auto src = op.value();
+        rewriter.replaceOpWithNewOp<mlir::tensor::CastOp>(op, type, src);
+        return mlir::success();
+    }
+};
+
 void PlierToLinalgPass::runOnOperation()
 {
     auto context = &getContext();
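
Not part of the commit itself: a standalone sketch of the indexing-map rewrite that SimplifyExpandDims performs, extracted into a hypothetical free function (the name simplifyBroadcastMap is chosen here only for illustration). It applies the same per-dimension rule as the pattern above: a constant-0 index into an output dimension of static size 1 becomes the corresponding loop dimension, so a broadcast map such as (d0, d1) -> (d0, 0) over an output whose second dimension has size 1 turns into the identity map (d0, d1) -> (d0, d1).

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

// Illustrative helper (not in the patch): rewrite a broadcast-style indexing
// map the same way SimplifyExpandDims does, one result expression at a time.
static mlir::AffineMap simplifyBroadcastMap(mlir::AffineMap in_map,
                                            llvm::ArrayRef<int64_t> out_shape,
                                            mlir::MLIRContext *ctx)
{
    llvm::SmallVector<mlir::AffineExpr> exprs;
    for (unsigned i = 0; i < in_map.getNumResults(); ++i)
    {
        auto expr = in_map.getResult(i);
        auto const_expr = expr.dyn_cast<mlir::AffineConstantExpr>();
        // A constant 0 into a unit-sized output dimension becomes the loop dim.
        if (out_shape[i] == 1 && const_expr && const_expr.getValue() == 0)
        {
            exprs.push_back(mlir::getAffineDimExpr(i, ctx));
        }
        else
        {
            exprs.push_back(expr);
        }
    }
    return mlir::AffineMap::get(in_map.getNumDims(), 0, exprs, ctx);
}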
@@ -801,38 +899,61 @@ void LowerLinalgPass::runOnOperation()
     (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
-struct CommonOptPass :
-    public mlir::PassWrapper<CommonOptPass, mlir::OperationPass<mlir::ModuleOp>>
+struct PostPlierToLinalgPass :
+    public mlir::PassWrapper<PostPlierToLinalgPass, mlir::OperationPass<mlir::ModuleOp>>
 {
-    virtual void getDependentDialects(
-        mlir::DialectRegistry &registry) const override
-    {
-        registry.insert<mlir::StandardOpsDialect>();
-        registry.insert<mlir::linalg::LinalgDialect>();
-        registry.insert<mlir::scf::SCFDialect>();
-        registry.insert<mlir::AffineDialect>();
-    }
+    void runOnOperation() override;
+};
+
+void PostPlierToLinalgPass::runOnOperation()
+{
+    mlir::OwningRewritePatternList patterns;
+
+    auto& context = getContext();
+    plier::populate_common_opts_patterns(context, patterns);
 
+    patterns.insert<
+        SimplifyExpandDims
+        >(&getContext());
+
+    (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+struct TensorFusionPass :
+    public mlir::PassWrapper<TensorFusionPass, mlir::OperationPass<mlir::ModuleOp>>
+{
     void runOnOperation() override;
 };
 
-void CommonOptPass::runOnOperation()
+void TensorFusionPass::runOnOperation()
 {
     mlir::OwningRewritePatternList patterns;
 
     auto& context = getContext();
-    for (auto *op : context.getRegisteredOperations())
-    {
-        op->getCanonicalizationPatterns(patterns, &context);
-    }
+    plier::populate_common_opts_patterns(context, patterns);
 
     patterns.insert<
-//        LoopInvariantCodeMotion, TODO
-        plier::ForceInline,
-        plier::CSERewrite<mlir::FuncOp>
-        >(&context);
+        SimplifyExpandDims,
+        LowerEnforceShape
+        >(&getContext());
+
+    mlir::populateLinalgTensorOpsFusionPatterns(&context, patterns);
 
-    plier::populate_index_propagate_patterns(context, patterns);
+    (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+struct CommonOptPass :
+    public mlir::PassWrapper<CommonOptPass, mlir::OperationPass<mlir::ModuleOp>>
+{
+    void runOnOperation() override;
+};
+
+void CommonOptPass::runOnOperation()
+{
+    mlir::OwningRewritePatternList patterns;
+
+    auto& context = getContext();
+    plier::populate_common_opts_patterns(context, patterns);
 
     (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
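
The new plier/rewrites/common_opts.hpp header is not part of this diff, so the exact contents of plier::populate_common_opts_patterns are not visible here. Judging only from the inline code it replaces, the helper presumably collects roughly the following; this is a reconstruction under that assumption, not the real header (includes for the plier rewrites omitted):

// Assumed shape of the new helper, inferred from the removed code in this file:
// canonicalization patterns of every registered op, plier::ForceInline,
// function-level CSE, and the index-propagation patterns. Whether
// plier::CmpLoopBoundsSimplify moved in here as well is a guess based on its
// removal from PostLinalgOptPass further down.
namespace plier
{
void populate_common_opts_patterns(mlir::MLIRContext& context,
                                   mlir::OwningRewritePatternList& patterns)
{
    for (auto *op : context.getRegisteredOperations())
    {
        op->getCanonicalizationPatterns(patterns, &context);
    }

    patterns.insert<
        plier::ForceInline,
        plier::CmpLoopBoundsSimplify,
        plier::CSERewrite<mlir::FuncOp>
        >(&context);

    plier::populate_index_propagate_patterns(context, patterns);
}
}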
@@ -897,15 +1018,6 @@ void RetainArgsPass::runOnFunction()
 struct PostLinalgOptPass :
     public mlir::PassWrapper<PostLinalgOptPass, mlir::OperationPass<mlir::ModuleOp>>
 {
-    virtual void getDependentDialects(
-        mlir::DialectRegistry &registry) const override
-    {
-        registry.insert<mlir::StandardOpsDialect>();
-        registry.insert<mlir::linalg::LinalgDialect>();
-        registry.insert<mlir::scf::SCFDialect>();
-        registry.insert<mlir::AffineDialect>();
-    }
-
     void runOnOperation() override;
 };
@@ -914,35 +1026,26 @@ void PostLinalgOptPass::runOnOperation()
     mlir::OwningRewritePatternList patterns;
 
     auto& context = getContext();
-    for (auto *op : context.getRegisteredOperations())
-    {
-        op->getCanonicalizationPatterns(patterns, &context);
-    }
+    plier::populate_common_opts_patterns(context, patterns);
 
     patterns.insert<
         plier::CanonicalizeReduction,
-//        LoopInvariantCodeMotion, TODO
-        plier::PromoteToParallel,
-        plier::CmpLoopBoundsSimplify,
-        plier::CSERewrite<mlir::FuncOp>
+        plier::PromoteToParallel
         >(&context);
 
-    plier::populate_index_propagate_patterns(context, patterns);
-
     (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
 void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm)
 {
     pm.addPass(std::make_unique<PlierToLinalgPass>());
-    pm.addPass(std::make_unique<CommonOptPass>());
+    pm.addPass(std::make_unique<PostPlierToLinalgPass>());
     pm.addPass(mlir::createSymbolDCEPass());
 }
 
 void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)
 {
-    pm.addPass(mlir::createLinalgFusionOfTensorOpsPass());
-    pm.addPass(std::make_unique<CommonOptPass>());
+    pm.addPass(std::make_unique<TensorFusionPass>());
 
     pm.addPass(mlir::createTensorConstantBufferizePass());
     pm.addNestedPass<mlir::FuncOp>(mlir::createSCFBufferizePass());
@@ -958,10 +1061,14 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)
 
     pm.addNestedPass<mlir::FuncOp>(std::make_unique<RetainArgsPass>());
     pm.addNestedPass<mlir::FuncOp>(mlir::createBufferDeallocationPass());
+    pm.addPass(mlir::createCopyRemovalPass());
 
     pm.addPass(std::make_unique<LowerLinalgPass>());
+    pm.addPass(mlir::createParallelLoopFusionPass());
     pm.addPass(std::make_unique<PostLinalgOptPass>());
     pm.addPass(mlir::createSymbolDCEPass());
+    pm.addPass(mlir::createParallelLoopFusionPass()); // TODO: make this a rewrite and add it to PostLinalgOptPass
+    pm.addPass(std::make_unique<PostLinalgOptPass>());
 }
 }
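
For context only (no caller appears in this diff): a hypothetical driver that chains the two public entry points on one pass manager. The gen-then-opt ordering is an assumption drawn from the pipeline names, not something this file shows.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical usage sketch; assumes a ModuleOp produced by the earlier plier
// lowering stages.
static mlir::LogicalResult runPlierToLinalg(mlir::ModuleOp module)
{
    mlir::PassManager pm(module.getContext());
    populate_plier_to_linalg_gen_pipeline(pm);  // plier -> linalg generation + cleanup
    populate_plier_to_linalg_opt_pipeline(pm);  // fusion, bufferization, lowering, post opts
    return pm.run(module);
}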
9671074
0 commit comments