Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit c22682f

Browse files
author
Ivan Butygin
authored
[MLIR] Numpy empty and sum axis (#187)
* store context in Var * rework shape accessor * dtype accessor rework * remove unused code * refactor shape * remove unused code * fix setitem lowering * numpy empty * numpy.sum * some kwargs support * linlag resolver kwargs support * linalg resolver some literal support * work on linalg resolver * add symbolDCE pass * numpy sum axis support
1 parent 0d19002 commit c22682f

File tree

9 files changed

+456
-148
lines changed

9 files changed

+456
-148
lines changed

mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,12 @@ bool is_int(mlir::Type type)
133133
return type.isa<mlir::IntegerType>();
134134
}
135135

136-
mlir::LogicalResult lower_prange(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, mlir::PatternRewriter& rewriter)
136+
mlir::LogicalResult lower_prange(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs, mlir::PatternRewriter& rewriter)
137137
{
138+
if (!kwargs.empty())
139+
{
140+
return mlir::failure();
141+
}
138142
if ((operands.size() < 1 || operands.size() > 3) ||
139143
!llvm::all_of(operands, [](mlir::Value val) { return is_int(val.getType());}))
140144
{
@@ -177,21 +181,21 @@ struct CallLowerer
177181
{
178182
mlir::LogicalResult operator()(
179183
plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef<mlir::Value> args,
180-
mlir::PatternRewriter& rewriter)
184+
llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs, mlir::PatternRewriter& rewriter)
181185
{
182-
using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef<mlir::Value>, mlir::PatternRewriter&);
186+
using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef<mlir::Value>, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>>, mlir::PatternRewriter&);
183187
std::pair<llvm::StringRef, func_t> handlers[] = {
184188
{"numba.prange", lower_prange},
185189
};
186190
for (auto& handler : handlers)
187191
{
188192
if (handler.first == name)
189193
{
190-
return handler.second(op, args, rewriter);
194+
return handler.second(op, args, kwargs, rewriter);
191195
}
192196
}
193197

194-
if (auto result = linalg_resolver.rewrite(name, op.getLoc(), rewriter, args))
198+
if (auto result = linalg_resolver.rewrite(name, op.getLoc(), rewriter, args, kwargs))
195199
{
196200
assert(result->size() == op->getNumResults());
197201
rerun_std_pipeline(op);
@@ -206,7 +210,7 @@ struct CallLowerer
206210
return mlir::success();
207211
}
208212

209-
if (name == "len" && check_numpy_args(args, 1))
213+
if (name == "len" && check_numpy_args(args, 1) && kwargs.empty())
210214
{
211215
auto loc = op.getLoc();
212216
mlir::Value dim = rewriter.create<mlir::DimOp>(loc, args[0], 0);
@@ -219,7 +223,6 @@ struct CallLowerer
219223
}
220224

221225
private:
222-
223226
PyLinalgResolver linalg_resolver;
224227
};
225228

@@ -436,7 +439,7 @@ struct SetitemOpLowering : public mlir::OpRewritePattern<T>
436439
mlir::OpBuilder::InsertionGuard g(rewriter);
437440
if (auto parent_op = target.getDefiningOp())
438441
{
439-
rewriter.setInsertionPoint(parent_op);
442+
rewriter.setInsertionPointAfter(parent_op);
440443
}
441444
else
442445
{
@@ -456,6 +459,7 @@ struct SetitemOpLowering : public mlir::OpRewritePattern<T>
456459
}
457460
else
458461
{
462+
mlir::OpBuilder::InsertionGuard g(rewriter);
459463
rewriter.setInsertionPoint(use_op);
460464
auto new_val = rewriter.create<mlir::TensorLoadOp>(use_op->getLoc(), memref);
461465
rewriter.updateRootInPlace(use_op, [&]()
@@ -602,6 +606,7 @@ struct LowerLinalgPass :
602606
mlir::DialectRegistry &registry) const override
603607
{
604608
registry.insert<mlir::StandardOpsDialect>();
609+
registry.insert<mlir::tensor::TensorDialect>();
605610
registry.insert<mlir::linalg::LinalgDialect>();
606611
registry.insert<mlir::scf::SCFDialect>();
607612
registry.insert<mlir::AffineDialect>();
@@ -686,6 +691,7 @@ void PostLinalgOptPass::runOnOperation()
686691
void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm)
687692
{
688693
pm.addPass(std::make_unique<PlierToLinalgPass>());
694+
pm.addPass(mlir::createSymbolDCEPass());
689695
}
690696

691697
void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)
@@ -708,6 +714,7 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)
708714

709715
pm.addPass(std::make_unique<LowerLinalgPass>());
710716
pm.addPass(std::make_unique<PostLinalgOptPass>());
717+
pm.addPass(mlir::createSymbolDCEPass());
711718
}
712719
}
713720

mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,8 +1156,12 @@ struct FoldTupleGetitem : public mlir::OpRewritePattern<Op>
11561156
}
11571157
};
11581158

1159-
mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, mlir::PatternRewriter& rewriter)
1159+
mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs, mlir::PatternRewriter& rewriter)
11601160
{
1161+
if (!kwargs.empty())
1162+
{
1163+
return mlir::failure();
1164+
}
11611165
if ((operands.size() < 1 || operands.size() > 3) ||
11621166
!llvm::all_of(operands, [](mlir::Value val) { return is_int(val.getType());}))
11631167
{
@@ -1191,8 +1195,12 @@ mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef<mlir::Value>
11911195
return mlir::success();
11921196
}
11931197

1194-
mlir::LogicalResult lower_len(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, mlir::PatternRewriter& rewriter)
1198+
mlir::LogicalResult lower_len(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs, mlir::PatternRewriter& rewriter)
11951199
{
1200+
if (!kwargs.empty())
1201+
{
1202+
return mlir::failure();
1203+
}
11961204
if (operands.size() != 1)
11971205
{
11981206
return mlir::failure();
@@ -1210,8 +1218,12 @@ mlir::LogicalResult lower_len(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> op
12101218
return mlir::success();
12111219
}
12121220

1213-
mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, mlir::PatternRewriter& rewriter)
1221+
mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs, mlir::PatternRewriter& rewriter)
12141222
{
1223+
if (!kwargs.empty())
1224+
{
1225+
return mlir::failure();
1226+
}
12151227
if (operands.size() != 1)
12161228
{
12171229
return mlir::failure();
@@ -1250,8 +1262,12 @@ mlir::FuncOp get_lib_symbol(
12501262

12511263
mlir::LogicalResult lower_math_func(
12521264
plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef<mlir::Value> args,
1253-
mlir::PatternRewriter& rewriter)
1265+
llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs, mlir::PatternRewriter& rewriter)
12541266
{
1267+
if (!kwargs.empty())
1268+
{
1269+
return mlir::failure();
1270+
}
12551271
auto ret_type = map_plier_type(op.getType());
12561272
auto valid_type = [&](mlir::Type type)
12571273
{
@@ -1285,14 +1301,14 @@ mlir::LogicalResult lower_math_func(
12851301
struct CallLowerer
12861302
{
12871303
mlir::LogicalResult operator()(plier::PyCallOp op, llvm::StringRef name,
1288-
llvm::ArrayRef<mlir::Value> args, mlir::PatternRewriter& rewriter)
1304+
llvm::ArrayRef<mlir::Value> args, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs, mlir::PatternRewriter& rewriter)
12891305
{
1290-
if (mlir::succeeded(lower_math_func(op, name, args, rewriter)))
1306+
if (mlir::succeeded(lower_math_func(op, name, args, kwargs, rewriter)))
12911307
{
12921308
return mlir::success();
12931309
}
12941310

1295-
using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef<mlir::Value>, mlir::PatternRewriter&);
1311+
using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef<mlir::Value>, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>>, mlir::PatternRewriter&);
12961312
std::pair<llvm::StringRef, func_t> handlers[] = {
12971313
{"bool", lower_bool_cast},
12981314
{"range", lower_range},
@@ -1302,7 +1318,7 @@ struct CallLowerer
13021318
{
13031319
if (handler.first == name)
13041320
{
1305-
return handler.second(op, args, rewriter);
1321+
return handler.second(op, args, kwargs, rewriter);
13061322
}
13071323
}
13081324

0 commit comments

Comments
 (0)