1111#include < mlir/Pass/PassManager.h>
1212#include < mlir/Pass/Pass.h>
1313#include < mlir/Transforms/GreedyPatternRewriteDriver.h>
14+ #include < mlir/Transforms/Passes.h>
1415
1516#include < llvm/ADT/Triple.h>
1617#include < llvm/ADT/TypeSwitch.h>
@@ -871,6 +872,7 @@ struct LowerParallel : public mlir::OpRewritePattern<plier::ParallelOp>
871872 mlir::LogicalResult
872873 matchAndRewrite (plier::ParallelOp op,
873874 mlir::PatternRewriter &rewriter) const override {
875+ auto num_loops = op.getNumLoops ();
874876 llvm::SmallVector<mlir::Value> context_vars;
875877 llvm::SmallVector<mlir::Operation*> context_constants;
876878 llvm::DenseSet<mlir::Value> context_vars_set;
@@ -951,6 +953,24 @@ struct LowerParallel : public mlir::OpRewritePattern<plier::ParallelOp>
951953 auto context_ptr_type = mlir::LLVM::LLVMPointerType::get (context_type);
952954
953955 auto loc = op.getLoc ();
956+ auto index_type = rewriter.getIndexType ();
957+ auto llvm_index_type = mlir::IntegerType::get (op.getContext (), 64 ); // TODO
// Local helper: yields a value of type `llvm_index_type`.
// If `val` already has that type it is returned unchanged; otherwise an
// mlir::LLVM::BitcastOp to `llvm_index_type` is created at `loc` via `rewriter`
// and its result is returned.
// NOTE(review): `llvm_index_type` is a hard-coded 64-bit integer where it is
// declared (marked TODO there) - confirm it matches the target's index width.
958+ auto to_llvm_index = [&](mlir::Value val)->mlir ::Value
959+ {
960+ if (val.getType () != llvm_index_type)
961+ {
962+ return rewriter.create <mlir::LLVM::BitcastOp>(loc, llvm_index_type, val);
963+ }
964+ return val;
965+ };
// Local helper: yields a value of type `index_type`.
// If `val` already has that type it is returned unchanged; otherwise a
// plier::CastOp to `index_type` is created at `loc` via `rewriter` and its
// result is returned.
966+ auto from_llvm_index = [&](mlir::Value val)->mlir ::Value
967+ {
968+ if (val.getType () != index_type)
969+ {
970+ return rewriter.create <plier::CastOp>(loc, index_type, val);
971+ }
972+ return val;
973+ };
954974 auto llvm_i32_type = mlir::IntegerType::get (op.getContext (), 32 );
955975 auto zero = rewriter.create <mlir::LLVM::ConstantOp>(loc, llvm_i32_type, rewriter.getI32IntegerAttr (0 ));
956976 auto one = rewriter.create <mlir::LLVM::ConstantOp>(loc, llvm_i32_type, rewriter.getI32IntegerAttr (1 ));
@@ -971,12 +991,29 @@ struct LowerParallel : public mlir::OpRewritePattern<plier::ParallelOp>
971991 auto void_ptr_type = mlir::LLVM::LLVMPointerType::get (mlir::IntegerType::get (op.getContext (), 8 ));
972992 auto context_abstract = rewriter.create <mlir::LLVM::BitcastOp>(loc, void_ptr_type, context);
973993
974- auto index_type = rewriter.getIndexType ();
// Immediately-invoked lambda that builds a literal LLVM struct type with three
// `llvm_index_type` members, in the order: lower_bound, upper_bound, step.
994+ auto input_range_type = [&]()
995+ {
996+ const mlir::Type members[] = {
997+ llvm_index_type, // lower_bound
998+ llvm_index_type, // upper_bound
999+ llvm_index_type, // step
1000+ };
1001+ return mlir::LLVM::LLVMStructType::getLiteral (op.getContext (), members);
1002+ }();
1003+ auto input_range_ptr = mlir::LLVM::LLVMPointerType::get (input_range_type);
// Immediately-invoked lambda that builds a literal LLVM struct type with two
// `llvm_index_type` members, in the order: lower_bound, upper_bound
// (unlike `input_range_type`, there is no step member).
1004+ auto range_type = [&]()
1005+ {
1006+ const mlir::Type members[] = {
1007+ llvm_index_type, // lower_bound
1008+ llvm_index_type, // upper_bound
1009+ };
1010+ return mlir::LLVM::LLVMStructType::getLiteral (op.getContext (), members);
1011+ }();
1012+ auto range_ptr = mlir::LLVM::LLVMPointerType::get (range_type);
9751013 auto func_type = [&]()
9761014 {
977- mlir::Type args[] = {
978- index_type, // lower_bound
979- index_type, // upper_bound
1015+ const mlir::Type args[] = {
1016+ range_ptr, // bounds
9801017 index_type, // thread index
9811018 void_ptr_type // context
9821019 };
@@ -1014,21 +1051,34 @@ struct LowerParallel : public mlir::OpRewritePattern<plier::ParallelOp>
10141051 auto entry = func.addEntryBlock ();
10151052 auto loc = rewriter.getUnknownLoc ();
10161053 mlir::OpBuilder::InsertionGuard guard (rewriter);
1017- mapping.map (old_entry.getArgument (0 ), entry->getArgument (0 ));
1018- mapping.map (old_entry.getArgument (1 ), entry->getArgument (1 ));
1019- mapping.map (old_entry.getArgument (2 ), entry->getArgument (2 ));
10201054 rewriter.setInsertionPointToStart (entry);
1055+ auto pos0 = rewriter.getI64ArrayAttr (0 );
1056+ auto pos1 = rewriter.getI64ArrayAttr (1 );
1057+ for (unsigned i = 0 ; i < num_loops; ++i)
1058+ {
1059+ auto arg = entry->getArgument (0 );
1060+ const mlir::Value indices[] = {
1061+ rewriter.create <mlir::LLVM::ConstantOp>(loc, llvm_i32_type, rewriter.getI32IntegerAttr (static_cast <int32_t >(i)))
1062+ };
1063+ auto ptr = rewriter.create <mlir::LLVM::GEPOp>(loc, range_ptr, arg, indices);
1064+ auto dims = rewriter.create <mlir::LLVM::LoadOp>(loc, ptr);
1065+ auto lower = rewriter.create <mlir::LLVM::ExtractValueOp>(loc, llvm_index_type, dims, pos0);
1066+ auto upper = rewriter.create <mlir::LLVM::ExtractValueOp>(loc, llvm_index_type, dims, pos1);
1067+ mapping.map (old_entry.getArgument (i), from_llvm_index (lower));
1068+ mapping.map (old_entry.getArgument (i + num_loops), from_llvm_index (upper));
1069+ }
1070+ mapping.map (old_entry.getArgument (2 * num_loops), entry->getArgument (1 )); // thread index
10211071 for (auto arg : context_constants)
10221072 {
10231073 rewriter.clone (*arg, mapping);
10241074 }
1025- auto context_ptr = rewriter.create <mlir::LLVM::BitcastOp>(loc, context_ptr_type, entry->getArgument (3 ));
1075+ auto context_ptr = rewriter.create <mlir::LLVM::BitcastOp>(loc, context_ptr_type, entry->getArgument (2 ));
10261076 auto zero = rewriter.create <mlir::LLVM::ConstantOp>(loc, llvm_i32_type, rewriter.getI32IntegerAttr (0 ));
10271077 for (auto it : llvm::enumerate (context_vars))
10281078 {
10291079 auto index = it.index ();
10301080 auto old_val = it.value ();
1031- mlir::Value indices[] = {
1081+ const mlir::Value indices[] = {
10321082 zero,
10331083 rewriter.create <mlir::LLVM::ConstantOp>(loc, llvm_i32_type, rewriter.getI32IntegerAttr (static_cast <int32_t >(index)))
10341084 };
@@ -1060,21 +1110,39 @@ struct LowerParallel : public mlir::OpRewritePattern<plier::ParallelOp>
10601110 {
10611111 return sym;
10621112 }
1063- mlir::Type args[] = {
1064- index_type, // lower bound
1065- index_type, // upper bound
1066- index_type, // step
1067- func_type,
1068- void_ptr_type
1113+ const mlir::Type args[] = {
1114+ input_range_ptr, // bounds
1115+ index_type, // num_loops
1116+ func_type, // func
1117+ void_ptr_type // context
10691118 };
1070- auto func_type = mlir::FunctionType::get (op.getContext (), args, {});
1071- return plier::add_function (rewriter, mod, func_name, func_type );
1119+ auto parallel_func_type = mlir::FunctionType::get (op.getContext (), args, {});
1120+ return plier::add_function (rewriter, mod, func_name, parallel_func_type );
10721121 }();
10731122 auto func_addr = rewriter.create <mlir::ConstantOp>(loc, func_type, rewriter.getSymbolRefAttr (outlined_func));
1074- mlir::Value pf_args[] = {
1075- op.lowerBound (),
1076- op.upperBound (),
1077- op.step (),
1123+
1124+ auto num_loops_var = rewriter.create <mlir::ConstantIndexOp>(loc, num_loops);
1125+ auto input_ranges = rewriter.create <mlir::LLVM::AllocaOp>(loc, input_range_ptr, to_llvm_index (num_loops_var), 0 );
1126+ for (unsigned i = 0 ; i < num_loops; ++i)
1127+ {
1128+ mlir::Value input_range = rewriter.create <mlir::LLVM::UndefOp>(loc, input_range_type);
1129+ auto insert = [&](mlir::Value val, unsigned index)
1130+ {
1131+ input_range = rewriter.create <mlir::LLVM::InsertValueOp>(loc, input_range, val, rewriter.getI64ArrayAttr (index));
1132+ };
1133+ insert (to_llvm_index (op.lowerBounds ()[i]), 0 );
1134+ insert (to_llvm_index (op.upperBounds ()[i]), 1 );
1135+ insert (to_llvm_index (op.steps ()[i]), 2 );
1136+ const mlir::Value indices[] = {
1137+ rewriter.create <mlir::LLVM::ConstantOp>(loc, llvm_i32_type, rewriter.getI32IntegerAttr (static_cast <int >(i)))
1138+ };
1139+ auto ptr = rewriter.create <mlir::LLVM::GEPOp>(loc, input_range_ptr, input_ranges, indices);
1140+ rewriter.create <mlir::LLVM::StoreOp>(loc, input_range, ptr);
1141+ }
1142+
1143+ const mlir::Value pf_args[] = {
1144+ input_ranges,
1145+ num_loops_var,
10781146 func_addr,
10791147 context_abstract
10801148 };
@@ -1226,6 +1294,7 @@ void populate_lower_to_llvm_pipeline(mlir::OpPassManager& pm)
12261294{
12271295 pm.addPass (std::make_unique<LowerParallelToCFGPass>());
12281296 pm.addPass (mlir::createLowerToCFGPass ());
1297+ pm.addPass (mlir::createCanonicalizerPass ());
12291298// pm.addPass(std::make_unique<CheckForPlierTypes>());
12301299 pm.addNestedPass <mlir::FuncOp>(std::make_unique<PreLLVMLowering>());
12311300 pm.addPass (std::make_unique<LLVMLoweringPass>(getLLVMOptions ()));
0 commit comments