Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit b0900a5

Browse files
author
Ivan Butygin
authored
[MLIR] some multiindex ParallelOp supports (#195)
* change builder func
* accept list of bounds
* plier::ParallelOp nested loops support
* fix to nested parallel loops
1 parent b3fe4a3 commit b0900a5

File tree

6 files changed

+227
-65
lines changed

6 files changed

+227
-65
lines changed

mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp

Lines changed: 90 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <mlir/Pass/PassManager.h>
1212
#include <mlir/Pass/Pass.h>
1313
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
14+
#include <mlir/Transforms/Passes.h>
1415

1516
#include <llvm/ADT/Triple.h>
1617
#include <llvm/ADT/TypeSwitch.h>
@@ -871,6 +872,7 @@ struct LowerParallel : public mlir::OpRewritePattern<plier::ParallelOp>
871872
mlir::LogicalResult
872873
matchAndRewrite(plier::ParallelOp op,
873874
mlir::PatternRewriter &rewriter) const override {
875+
auto num_loops = op.getNumLoops();
874876
llvm::SmallVector<mlir::Value> context_vars;
875877
llvm::SmallVector<mlir::Operation*> context_constants;
876878
llvm::DenseSet<mlir::Value> context_vars_set;
@@ -951,6 +953,24 @@ struct LowerParallel : public mlir::OpRewritePattern<plier::ParallelOp>
951953
auto context_ptr_type = mlir::LLVM::LLVMPointerType::get(context_type);
952954

953955
auto loc = op.getLoc();
956+
auto index_type = rewriter.getIndexType();
957+
auto llvm_index_type = mlir::IntegerType::get(op.getContext(), 64); // TODO
958+
auto to_llvm_index = [&](mlir::Value val)->mlir::Value
959+
{
960+
if (val.getType() != llvm_index_type)
961+
{
962+
return rewriter.create<mlir::LLVM::BitcastOp>(loc, llvm_index_type, val);
963+
}
964+
return val;
965+
};
966+
auto from_llvm_index = [&](mlir::Value val)->mlir::Value
967+
{
968+
if (val.getType() != index_type)
969+
{
970+
return rewriter.create<plier::CastOp>(loc, index_type, val);
971+
}
972+
return val;
973+
};
954974
auto llvm_i32_type = mlir::IntegerType::get(op.getContext(), 32);
955975
auto zero = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvm_i32_type, rewriter.getI32IntegerAttr(0));
956976
auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvm_i32_type, rewriter.getI32IntegerAttr(1));
@@ -971,12 +991,29 @@ struct LowerParallel : public mlir::OpRewritePattern<plier::ParallelOp>
971991
auto void_ptr_type = mlir::LLVM::LLVMPointerType::get(mlir::IntegerType::get(op.getContext(), 8));
972992
auto context_abstract = rewriter.create<mlir::LLVM::BitcastOp>(loc, void_ptr_type, context);
973993

974-
auto index_type = rewriter.getIndexType();
994+
auto input_range_type = [&]()
995+
{
996+
const mlir::Type members[] = {
997+
llvm_index_type, // lower_bound
998+
llvm_index_type, // upper_bound
999+
llvm_index_type, // step
1000+
};
1001+
return mlir::LLVM::LLVMStructType::getLiteral(op.getContext(), members);
1002+
}();
1003+
auto input_range_ptr = mlir::LLVM::LLVMPointerType::get(input_range_type);
1004+
auto range_type = [&]()
1005+
{
1006+
const mlir::Type members[] = {
1007+
llvm_index_type, // lower_bound
1008+
llvm_index_type, // upper_bound
1009+
};
1010+
return mlir::LLVM::LLVMStructType::getLiteral(op.getContext(), members);
1011+
}();
1012+
auto range_ptr = mlir::LLVM::LLVMPointerType::get(range_type);
9751013
auto func_type = [&]()
9761014
{
977-
mlir::Type args[] = {
978-
index_type, // lower_bound
979-
index_type, // upper_bound
1015+
const mlir::Type args[] = {
1016+
range_ptr, // bounds
9801017
index_type, // thread index
9811018
void_ptr_type // context
9821019
};
@@ -1014,21 +1051,34 @@ struct LowerParallel : public mlir::OpRewritePattern<plier::ParallelOp>
10141051
auto entry = func.addEntryBlock();
10151052
auto loc = rewriter.getUnknownLoc();
10161053
mlir::OpBuilder::InsertionGuard guard(rewriter);
1017-
mapping.map(old_entry.getArgument(0), entry->getArgument(0));
1018-
mapping.map(old_entry.getArgument(1), entry->getArgument(1));
1019-
mapping.map(old_entry.getArgument(2), entry->getArgument(2));
10201054
rewriter.setInsertionPointToStart(entry);
1055+
auto pos0 = rewriter.getI64ArrayAttr(0);
1056+
auto pos1 = rewriter.getI64ArrayAttr(1);
1057+
for (unsigned i = 0; i < num_loops; ++i)
1058+
{
1059+
auto arg = entry->getArgument(0);
1060+
const mlir::Value indices[] = {
1061+
rewriter.create<mlir::LLVM::ConstantOp>(loc, llvm_i32_type, rewriter.getI32IntegerAttr(static_cast<int32_t>(i)))
1062+
};
1063+
auto ptr = rewriter.create<mlir::LLVM::GEPOp>(loc, range_ptr, arg, indices);
1064+
auto dims = rewriter.create<mlir::LLVM::LoadOp>(loc, ptr);
1065+
auto lower = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, llvm_index_type, dims, pos0);
1066+
auto upper = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, llvm_index_type, dims, pos1);
1067+
mapping.map(old_entry.getArgument(i), from_llvm_index(lower));
1068+
mapping.map(old_entry.getArgument(i + num_loops), from_llvm_index(upper));
1069+
}
1070+
mapping.map(old_entry.getArgument(2 * num_loops), entry->getArgument(1)); // thread index
10211071
for (auto arg : context_constants)
10221072
{
10231073
rewriter.clone(*arg, mapping);
10241074
}
1025-
auto context_ptr = rewriter.create<mlir::LLVM::BitcastOp>(loc, context_ptr_type, entry->getArgument(3));
1075+
auto context_ptr = rewriter.create<mlir::LLVM::BitcastOp>(loc, context_ptr_type, entry->getArgument(2));
10261076
auto zero = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvm_i32_type, rewriter.getI32IntegerAttr(0));
10271077
for (auto it : llvm::enumerate(context_vars))
10281078
{
10291079
auto index = it.index();
10301080
auto old_val = it.value();
1031-
mlir::Value indices[] = {
1081+
const mlir::Value indices[] = {
10321082
zero,
10331083
rewriter.create<mlir::LLVM::ConstantOp>(loc, llvm_i32_type, rewriter.getI32IntegerAttr(static_cast<int32_t>(index)))
10341084
};
@@ -1060,21 +1110,39 @@ struct LowerParallel : public mlir::OpRewritePattern<plier::ParallelOp>
10601110
{
10611111
return sym;
10621112
}
1063-
mlir::Type args[] = {
1064-
index_type, // lower bound
1065-
index_type, // upper bound
1066-
index_type, // step
1067-
func_type,
1068-
void_ptr_type
1113+
const mlir::Type args[] = {
1114+
input_range_ptr, // bounds
1115+
index_type, // num_loops
1116+
func_type, // func
1117+
void_ptr_type // context
10691118
};
1070-
auto func_type = mlir::FunctionType::get(op.getContext(), args, {});
1071-
return plier::add_function(rewriter, mod, func_name, func_type);
1119+
auto parallel_func_type = mlir::FunctionType::get(op.getContext(), args, {});
1120+
return plier::add_function(rewriter, mod, func_name, parallel_func_type);
10721121
}();
10731122
auto func_addr = rewriter.create<mlir::ConstantOp>(loc, func_type, rewriter.getSymbolRefAttr(outlined_func));
1074-
mlir::Value pf_args[] = {
1075-
op.lowerBound(),
1076-
op.upperBound(),
1077-
op.step(),
1123+
1124+
auto num_loops_var = rewriter.create<mlir::ConstantIndexOp>(loc, num_loops);
1125+
auto input_ranges = rewriter.create<mlir::LLVM::AllocaOp>(loc, input_range_ptr, to_llvm_index(num_loops_var), 0);
1126+
for (unsigned i = 0; i < num_loops; ++i)
1127+
{
1128+
mlir::Value input_range = rewriter.create<mlir::LLVM::UndefOp>(loc, input_range_type);
1129+
auto insert = [&](mlir::Value val, unsigned index)
1130+
{
1131+
input_range = rewriter.create<mlir::LLVM::InsertValueOp>(loc, input_range, val, rewriter.getI64ArrayAttr(index));
1132+
};
1133+
insert(to_llvm_index(op.lowerBounds()[i]), 0);
1134+
insert(to_llvm_index(op.upperBounds()[i]), 1);
1135+
insert(to_llvm_index(op.steps()[i]), 2);
1136+
const mlir::Value indices[] = {
1137+
rewriter.create<mlir::LLVM::ConstantOp>(loc, llvm_i32_type, rewriter.getI32IntegerAttr(static_cast<int>(i)))
1138+
};
1139+
auto ptr = rewriter.create<mlir::LLVM::GEPOp>(loc, input_range_ptr, input_ranges, indices);
1140+
rewriter.create<mlir::LLVM::StoreOp>(loc, input_range, ptr);
1141+
}
1142+
1143+
const mlir::Value pf_args[] = {
1144+
input_ranges,
1145+
num_loops_var,
10781146
func_addr,
10791147
context_abstract
10801148
};
@@ -1226,6 +1294,7 @@ void populate_lower_to_llvm_pipeline(mlir::OpPassManager& pm)
12261294
{
12271295
pm.addPass(std::make_unique<LowerParallelToCFGPass>());
12281296
pm.addPass(mlir::createLowerToCFGPass());
1297+
pm.addPass(mlir::createCanonicalizerPass());
12291298
// pm.addPass(std::make_unique<CheckForPlierTypes>());
12301299
pm.addNestedPass<mlir::FuncOp>(std::make_unique<PreLLVMLowering>());
12311300
pm.addPass(std::make_unique<LLVMLoweringPass>(getLLVMOptions()));

mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,6 @@ struct ParallelToTbb : public mlir::OpRewritePattern<mlir::scf::ParallelOp>
4646
{
4747
return mlir::failure();
4848
}
49-
if (op.getNumLoops() != 1)
50-
{
51-
return mlir::failure();
52-
}
5349
bool need_parallel = op->hasAttr(plier::attributes::getParallelName()) ||
5450
!op->getParentOfType<mlir::scf::ParallelOp>();
5551
if (!need_parallel)
@@ -109,10 +105,10 @@ struct ParallelToTbb : public mlir::OpRewritePattern<mlir::scf::ParallelOp>
109105
rewriter.create<mlir::scf::ForOp>(loc, reduce_lower_bound, reduce_upper_bound, reduce_step, llvm::None, reduce_init_body_builder);
110106

111107
auto& old_body = op.getLoopBody().front();
112-
auto orig_lower_bound = op.lowerBound().front();
113-
auto orig_upper_bound = op.upperBound().front();
114-
auto orig_step = op.step().front();
115-
auto body_builder = [&](mlir::OpBuilder &builder, ::mlir::Location loc, mlir::Value lower_bound, mlir::Value upper_bound, mlir::Value thread_index)
108+
auto orig_lower_bound = op.lowerBound();
109+
auto orig_upper_bound = op.upperBound();
110+
auto orig_step = op.step();
111+
auto body_builder = [&](mlir::OpBuilder &builder, ::mlir::Location loc, mlir::ValueRange lower_bound, mlir::ValueRange upper_bound, mlir::Value thread_index)
116112
{
117113
llvm::SmallVector<mlir::Value> initVals(op.initVals().size());
118114
for (auto it : llvm::enumerate(op.initVals()))

mlir-compiler/plier/include/plier/PlierOps.td

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -215,21 +215,26 @@ def GetattrOp : Plier_Op<"getattr", [NoSideEffect]> {
215215
}
216216

217217
def ParallelOp : Plier_Op<"parallel",
218-
[DeclareOpInterfaceMethods<LoopLikeOpInterface>,
218+
[AttrSizedOperandSegments,
219+
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
219220
SingleBlockImplicitTerminator<"plier::YieldOp">,
220221
RecursiveSideEffects]> {
221222

222-
let arguments = (ins Index:$lowerBound,
223-
Index:$upperBound,
224-
Index:$step);
223+
let arguments = (ins Variadic<Index>:$lowerBounds,
224+
Variadic<Index>:$upperBounds,
225+
Variadic<Index>:$steps);
225226
let regions = (region SizedRegion<1>:$region);
226227

227228
let skipDefaultBuilders = 1;
228229
let builders = [
229-
OpBuilderDAG<(ins "::mlir::Value":$lowerBound, "::mlir::Value":$upperBound, "::mlir::Value":$step,
230-
CArg<"::mlir::function_ref<void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::Value, ::mlir::Value, ::mlir::Value)>",
230+
OpBuilderDAG<(ins "::mlir::ValueRange":$lowerBounds, "::mlir::ValueRange":$upperBounds, "::mlir::ValueRange":$steps,
231+
CArg<"::mlir::function_ref<void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::ValueRange, ::mlir::ValueRange, ::mlir::Value)>",
231232
"nullptr">)>
232233
];
234+
235+
let extraClassDeclaration = [{
236+
unsigned getNumLoops() { return steps().size(); }
237+
}];
233238
}
234239

235240
def YieldOp : Plier_Op<"yield", [NoSideEffect, ReturnLike, Terminator,

mlir-compiler/plier/src/dialect.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -354,23 +354,33 @@ bool ParallelOp::isDefinedOutsideOfLoop(mlir::Value value)
354354

355355
void ParallelOp::build(
356356
mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState,
357-
mlir::Value lowerBound, mlir::Value upperBound, mlir::Value step,
358-
mlir::function_ref<void(mlir::OpBuilder &, mlir::Location, mlir::Value,
359-
mlir::Value, mlir::Value)> bodyBuilder) {
360-
odsState.addOperands({lowerBound, upperBound, step});
357+
mlir::ValueRange lowerBounds, mlir::ValueRange upperBounds, mlir::ValueRange steps,
358+
mlir::function_ref<void(mlir::OpBuilder &, mlir::Location, mlir::ValueRange,
359+
mlir::ValueRange, mlir::Value)> bodyBuilder) {
360+
assert(lowerBounds.size() == upperBounds.size());
361+
assert(lowerBounds.size() == steps.size());
362+
odsState.addOperands(lowerBounds);
363+
odsState.addOperands(upperBounds);
364+
odsState.addOperands(steps);
365+
odsState.addAttribute(
366+
ParallelOp::getOperandSegmentSizeAttr(),
367+
odsBuilder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
368+
static_cast<int32_t>(upperBounds.size()),
369+
static_cast<int32_t>(steps.size())}));
361370
auto bodyRegion = odsState.addRegion();
362-
bodyRegion->push_back(new mlir::Block);
363-
auto& bodyBlock = bodyRegion->front();
364-
bodyBlock.addArgument(odsBuilder.getIndexType()); // lower bound
365-
bodyBlock.addArgument(odsBuilder.getIndexType()); // upper bound
366-
bodyBlock.addArgument(odsBuilder.getIndexType()); // thread index
371+
auto count = lowerBounds.size();
372+
mlir::OpBuilder::InsertionGuard guard(odsBuilder);
373+
llvm::SmallVector<mlir::Type> argTypes(count * 2 + 1, odsBuilder.getIndexType());
374+
auto *bodyBlock = odsBuilder.createBlock(bodyRegion, {}, argTypes);
367375

368376
if (bodyBuilder)
369377
{
370-
mlir::OpBuilder::InsertionGuard guard(odsBuilder);
371-
odsBuilder.setInsertionPointToStart(&bodyBlock);
372-
bodyBuilder(odsBuilder, odsState.location, bodyBlock.getArgument(0),
373-
bodyBlock.getArgument(1), bodyBlock.getArgument(2));
378+
odsBuilder.setInsertionPointToStart(bodyBlock);
379+
auto args = bodyBlock->getArguments();
380+
bodyBuilder(odsBuilder, odsState.location,
381+
args.take_front(count),
382+
args.drop_front(count).take_front(count),
383+
args.back());
374384
ParallelOp::ensureTerminator(*bodyRegion, odsBuilder, odsState.location);
375385
}
376386
}

numba/mlir/tests/test_numpy.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,5 +295,14 @@ def py_func(a, b):
295295
for a, b in itertools.product(test_data, test_data):
296296
assert_equal(py_func(a,b), jit_func(a,b))
297297

298+
def test_parallel(self):
299+
def py_func(a, b):
300+
return np.add(a, b)
301+
302+
jit_func = njit(py_func, parallel=True)
303+
arr = np.asarray([[[1,2,3],[4,5,6]],
304+
[[1,2,3],[4,5,6]]])
305+
assert_equal(py_func(arr,arr), jit_func(arr,arr))
306+
298307
if __name__ == '__main__':
299308
unittest.main()

0 commit comments

Comments
 (0)