@@ -133,8 +133,12 @@ bool is_int(mlir::Type type)
133133    return  type.isa <mlir::IntegerType>();
134134}
135135
136- mlir::LogicalResult lower_prange (plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, mlir::PatternRewriter& rewriter)
136+ mlir::LogicalResult lower_prange (plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs,  mlir::PatternRewriter& rewriter)
137137{
138+     if  (!kwargs.empty ())
139+     {
140+         return  mlir::failure ();
141+     }
138142    if  ((operands.size () < 1  || operands.size () > 3 ) ||
139143        !llvm::all_of (operands, [](mlir::Value val) { return  is_int (val.getType ());}))
140144    {
@@ -177,21 +181,21 @@ struct CallLowerer
177181{
178182    mlir::LogicalResult operator ()(
179183        plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef<mlir::Value> args,
180-         mlir::PatternRewriter& rewriter)
184+         llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs,  mlir::PatternRewriter& rewriter)
181185    {
182-         using  func_t  = mlir::LogicalResult (*)(plier::PyCallOp, llvm::ArrayRef<mlir::Value>, mlir::PatternRewriter&);
186+         using  func_t  = mlir::LogicalResult (*)(plier::PyCallOp, llvm::ArrayRef<mlir::Value>, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>>,  mlir::PatternRewriter&);
183187        std::pair<llvm::StringRef, func_t > handlers[] = {
184188            {" numba.prange"  , lower_prange},
185189        };
186190        for  (auto & handler : handlers)
187191        {
188192            if  (handler.first  == name)
189193            {
190-                 return  handler.second (op, args, rewriter);
194+                 return  handler.second (op, args, kwargs,  rewriter);
191195            }
192196        }
193197
194-         if  (auto  result = linalg_resolver.rewrite (name, op.getLoc (), rewriter, args))
198+         if  (auto  result = linalg_resolver.rewrite (name, op.getLoc (), rewriter, args, kwargs ))
195199        {
196200            assert (result->size () == op->getNumResults ());
197201            rerun_std_pipeline (op);
@@ -206,7 +210,7 @@ struct CallLowerer
206210            return  mlir::success ();
207211        }
208212
209-         if  (name == " len"   && check_numpy_args (args, 1 ))
213+         if  (name == " len"   && check_numpy_args (args, 1 ) && kwargs. empty () )
210214        {
211215            auto  loc = op.getLoc ();
212216            mlir::Value dim = rewriter.create <mlir::DimOp>(loc, args[0 ], 0 );
@@ -219,7 +223,6 @@ struct CallLowerer
219223    }
220224
221225private: 
222- 
223226    PyLinalgResolver linalg_resolver;
224227};
225228
@@ -436,7 +439,7 @@ struct SetitemOpLowering : public mlir::OpRewritePattern<T>
436439            mlir::OpBuilder::InsertionGuard g (rewriter);
437440            if  (auto  parent_op = target.getDefiningOp ())
438441            {
439-                 rewriter.setInsertionPoint (parent_op);
442+                 rewriter.setInsertionPointAfter (parent_op);
440443            }
441444            else 
442445            {
@@ -456,6 +459,7 @@ struct SetitemOpLowering : public mlir::OpRewritePattern<T>
456459                    }
457460                    else 
458461                    {
462+                         mlir::OpBuilder::InsertionGuard g (rewriter);
459463                        rewriter.setInsertionPoint (use_op);
460464                        auto  new_val = rewriter.create <mlir::TensorLoadOp>(use_op->getLoc (), memref);
461465                        rewriter.updateRootInPlace (use_op, [&]()
@@ -602,6 +606,7 @@ struct LowerLinalgPass :
602606        mlir::DialectRegistry ®istry) const  override 
603607    {
604608        registry.insert <mlir::StandardOpsDialect>();
609+         registry.insert <mlir::tensor::TensorDialect>();
605610        registry.insert <mlir::linalg::LinalgDialect>();
606611        registry.insert <mlir::scf::SCFDialect>();
607612        registry.insert <mlir::AffineDialect>();
@@ -686,6 +691,7 @@ void PostLinalgOptPass::runOnOperation()
686691void  populate_plier_to_linalg_gen_pipeline (mlir::OpPassManager& pm)
687692{
688693    pm.addPass (std::make_unique<PlierToLinalgPass>());
694+     pm.addPass (mlir::createSymbolDCEPass ());
689695}
690696
691697void  populate_plier_to_linalg_opt_pipeline (mlir::OpPassManager& pm)
@@ -708,6 +714,7 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)
708714
709715    pm.addPass (std::make_unique<LowerLinalgPass>());
710716    pm.addPass (std::make_unique<PostLinalgOptPass>());
717+     pm.addPass (mlir::createSymbolDCEPass ());
711718}
712719}
713720
0 commit comments